CN114360027A - 一种特征提取网络的训练方法、装置及电子设备 - Google Patents

一种特征提取网络的训练方法、装置及电子设备 Download PDF

Info

Publication number
CN114360027A
CN114360027A CN202210030447.8A CN202210030447A CN114360027A CN 114360027 A CN114360027 A CN 114360027A CN 202210030447 A CN202210030447 A CN 202210030447A CN 114360027 A CN114360027 A CN 114360027A
Authority
CN
China
Prior art keywords
class
category
target
image
sample image
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210030447.8A
Other languages
English (en)
Inventor
李弼
彭楠
希滕
张刚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202210030447.8A priority Critical patent/CN114360027A/zh
Publication of CN114360027A publication Critical patent/CN114360027A/zh
Pending legal-status Critical Current

Links

Images

Abstract

本公开提供了一种特征提取网络的训练方法、装置及电子设备,涉及数据处理技术领域,尤其涉及人工智能模型训练领域。具体实现方案为:获取训练集中的样本图像;将样本图像输入分类模型,得到模型输出结果和样本图像的目标图像特征;识别所述目标图像特征和指定特征的相似性,得到识别结果;利用模型输出结果与识别结果,生成样本图像的类别预测结果;基于类别预测结果和样本图像的类别标签,对分类模型的模型参数进行调整,并返回获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。通过本方案,可以在降低资源占用的同时保证特征提取网络的准确度。

Description

一种特征提取网络的训练方法、装置及电子设备
技术领域
本公开涉及数据处理技术领域,尤其涉及人工智能模型训练领域,具体涉及一种特征提取网络的训练方法、装置及电子设备。
背景技术
针对对象识别场景而言,通常利用特征提取网络,提取待识别图像的图像特征,然后,识别所提取图像特征与底库中各个图像特征的相似性,并基于识别出的相似性,确定出待识别图像中的对象。其中,底库中每一图像特征为针对一对象的图像特征。
相关技术中,训练用于进行对象分类的分类模型,待模型收敛后,将分类模型中的特征提取网络,确定为训练完成的特征提取网络,以用于后续的对象识别场景。
发明内容
本公开提供了一种特征提取网络的训练方法、装置及电子设备。
根据本公开的一方面,提供了一种特征提取网络的训练方法,包括:
获取训练集中的样本图像;
将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型包括特征提取网络和全连接层,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
根据本公开的另一方面,提供了一种特征提取网络的训练装置,包括:
获取模块,用于获取训练集中的样本图像;
输入模块,用于将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型包括特征提取网络和全连接层,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
识别模块,用于识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
生成模块,用于利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
调整模块,用于基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
根据本公开的另一方面,提供了一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述特征提取网络的训练方法。
根据本公开的另一方面,提供一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行上述特征提取网络的训练方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现上述特征提取网络的训练方法。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是本公开实施例提供的特征提取网络的训练方法的流程图;
图2是本公开实施例提供的特征提取网络的训练方法的另一流程图;
图3是本公开实施例提供的特征提取网络的训练装置的结构图;
图4是本公开实施例所提供的电子设备的结构示意图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
在对象识别场景中,如识别目标图像中人脸的身份信息,识别图像中动物的所属物种等,一般要通过特征提取网络对目标图像进行影像分析和变换,以提取图像特征;再比较提取的图像特征与底库中各个图像特征的相似性;当与底库中的某个类别的相似性达到阈值,可以认为该目标图像中的对象与该类别匹配。
相关技术中,针对特征提取网络的训练,训练集中每个类别都用全连接层中的一列向量表示,这样,通过全连接层建模的类别包含训练集中的所有类别,其中,每一种识别对象可以看做一个类别,如把每一个人当成一个类别。并且,经过特征提取网络提取的目标图像特征,与全连接层中的各个类别对应向量进行距离比对;当该目标图像特征与全连接层中的某一类别的向量计算得到的距离小于阈值,则确定该目标图像特征对应的样本图像属于该类别,得到分类的结果。其中,所使用的距离比对方式有多种,如欧式距离、余弦距离等。
这里的全连接层承担着聚集类内特征的作用,可以建模多样性较高或者包含一定噪声的类别,能够拉开类间距离,指引特征的走向,从而让特征提取网络更快收敛。但是这种方法由于要建模训练所有类别,会占用较多的计算资源,以及存储资源,例如:内存资源。
此外,特征提取网络还可以基于动态队列来进行训练。所谓动态队列,是用训练集中一个类别的样本特征作为该类别的代表,放入到队列中。队列中的类别动态变化的,且只包含训练集中所有样本的一部分类别,训练时比对特征和特征之间的距离。这种方法计算损失时不使用所有类别,能够有效减低存储资源和计算资源的占用,且训练任务和实际测试过程更加贴近。但是因为这种方式是用单个样本代表类别中心,在类别多样性较高或者噪声较多时网络较难收敛。
基于上述内容,在训练特征提取网络时,如何在降低资源占用的同时保证特征提取网络的准确度,是个亟待解决的问题。
针对上述问题,本公开提供了一种特征提取网络的训练方法、装置及电子设备。
下面首先对本公开实施例所提供的一种特征提取网络的训练方法进行介绍。
其中,本公开实施例所提供的一种特征提取网络的训练方法,应用于电子设备中。在实际应用中,该电子设备可以为服务器或终端设备,这都是合理的。
本公开实施例提供的一种特征提取网络的训练方法,可以包括以下步骤:
获取训练集中的样本图像;
将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型包括特征提取网络和全连接层,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
本实施例中,将样本图像输入分类模型,得到模型输出结果和样本图像的目标图像特征;分类模型的全连接层用于建模各个目标类别,各个目标类别为训练集所覆盖的部分类别;识别目标图像特征和指定特征的相似性,得到识别结果;指定特征为属于辅助类别的图像的图像特征;再利用模型输出结果与识别结果,生成样本图像的类别预测结果,最后基于类别预测结果和样本图像的类别标签对分类模型的模型参数进行调整,直至模型收敛,得到训练完成的特征提取网络。可见,本方案中,全连接层只对部分类别建模,使得计算资源和存储资源的占用大大降低,并通过辅助类别来保证对样本图像类别的覆盖,以保证特征提取网络的准确度。因此,通过本方案可以在降低资源占用的同时保证特征提取网络的准确度。
下面结合附图,对本公开所提供的一种特征提取网络的训练方法进行介绍。
如图1所示,本公开实施例所提供的一种特征提取网络的训练方法,可以包括如下步骤:
S101,获取训练集中的样本图像;
对象识别场景中,特征提取网络的训练一般是有监督的训练方式,即将样本图像及其对应的类别标签作为训练集,并将样本图片输入包含特征提取网络的分类模型中,得到类别预测值,基于类别预测值和类别标签调整参数,直到模型收敛,完成特征提取网络的训练。
本公开所提供的方案中,在训练特征提取网络的过程中,可以获取训练集中的至少一个样本图像,基于所获取的样本图像进行特征提取网络的训练。其中,训练集中样本图像的类别,与特征提取网络所需应用于的对象识别场景相关,例如:若对象识别场景为人脸识别场景,则训练集中的样本图像为人脸图像;若对象识别场景为动物识别场景,则训练集中的样本图像为动物图像。需要说明的是,本公开所涉及的训练集可以为公开训练集,或者,由从授权用户处获取的样本图像所构成训练集。
S102,将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
本公开中的分类模型包括特征提取网络和全连接层,其中,特征提取网络用于对输入图像进行特征提取,全连接层用于基于提取的特征对输入图像进行类别分类,本公开不对特征提取网络和全连接层的具体结构进行限定。
为了在降低资源占用,本公开提供的方案中,分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别,也就是,全连接层中针对训练集的部分类别,通过一列向量值来表示。部分类别的选择有多种可能的方式,示例性的,可以随机选择;也可以根据类别所包含的样本数进行选择,如在按照所包含样本数,对各个类别排序后,选择排位在前的K个类别;还可以根据类别的多样性进行选择,如在按照所包含样本图像的多样性高低,对各个类别进行排序后,选择排位在前的K个类别,等等。其中,选取部分类别时所依据的多样性可以指关于人脸角度的多样性,或者,关于年龄的多样性,本方案对此不做限制。
进而,将获取的样本图像输入该分类模型中,得到模型输出结果即该分类模型分类的结果,和特征提取网络所提取的该样本图像的目标图像特征。其中,模型输出结果表征了该样本图像的类别是否是全连接层中的某一类别。示例性的,其中,模型输出结果具体可以表征所述目标图像特征与针对各个目标类别的图像特征的相似性。
S103,识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
由于本方案中的全连接层所建模的类别,没有覆盖训练集所涉及的所有类别,而针对没有覆盖到的类别,该分类模型的全连接层无法进行有效的分类,因此,为了保证特征提取网络的准确度,本方案增加关于辅助类别的图像特征与目标图像特征相似分析。这样,保证各个目标类别加上辅助类别能够始终覆盖该样本图像的类别,以及后续可以得到该样本图像的有效的类别预测结果。
其中,所述目标图像特征和指定特征的相似性的识别过程,可以通过目标图像特征与指定特征进行距离比较得到。
并且,识别结果为目标样本图像与辅助图像的相似性,通过相似性可以表征该样本图像的类别是辅助类别中的哪个类别。可以理解的是,辅助类别中可能存在与全连接层中相同的类别;并且,辅助类别的数量可以是一个或多个,也就是,指定特征的数量可以是一个或多个。
在一种实现方式中,为了保证各个目标类别和辅助类别中存在所述样本图像所属的类别,可以预先构建多组指定特征,若样本图像的类别不属于任一目标类别,则将目标图像特征与目标组内的指定特征进行相似性比对,其中,目标组中包含关于该样本图像的类别的指定特征;若样本图像的类别属于任一目标类别,可以随意选择一组指定特征,将目标图像特征与所选择的一组指定特征进行相似性比对;当然,若样本图像的类别不属于任一目标类别,且多组指定特征中也不存在关于该样本图像的类别的特征,则可以获取样本图像对应的参考图像的图像特征,参考图像与所述样本图像的类别相同,并将参考图像的图像特征替换到某一组内一个特征,将目标图像特征与包含参考图像的图像特征的一组指定特征进行相似性比对。其中,指定特征以及参考图像的图像特征可以由分类模型中的特征提取网络所提取,也可以由其他特征提取网络所提取,这都是合理的。
另外,需要说明的是,为了保证模型训练的有效性,各个目标类别和辅助类别的类别总数是固定的,这样使得后续每次生成的类别预测值是关于固定数量的类别的预测值。
S104,利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
可以理解的是,类别预测结果表征:该样本图像是全连接层及辅助类别中的哪个类别。该样本图像的类别可能既在全连接层所建模的各个目标类别中,也在辅助类别中,当然,该样本图像的类别可能仅仅在辅助类别中,或者,仅仅在各个目标类别中。
其中,所述利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果,包括:
对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果。
由于各个目标类别和辅助类别中存在该样本图像所属的类别,因此,可以对模型输出结果和识别结果进行融合,得到样本图像的类别预测结果,从而保证得到类别预测结果的有效性,即类别预测结果所涉及的类别中包含该样本图像的类别。
示例性的,模型输出结果和所述识别结果均为向量形式;
所述对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果,包括:
将所述模型输出结果和识别结果进行向量拼接,得到所述样本图像的类别预测结果。
各个目标类别中的每一目标类别对应模型输出结果的一个维度,辅助类别中的每一类别对应识别结果的一个维度。可以理解的是,各个目标类别的类别数量,与模型输出结果的向量维度相同;辅助类别的类别数量与识别结果的向量维度相同。例如,该分类模型的全连接层建模了1000个目标类别,辅助类别的数量是100,则模型输出结果的向量维度是1000,识别结果的向量维度是100。
并且,将这两个向量进行拼接,得到向量形式的类别预测结果,类别预测结果的向量维度为模型输出结果和识别结果向量的维度之和。
需要强调的是,本公开并不对模型输出结果和识别结果的具体形式进行限定;不同的具体形式所对应的融合方式可以不同。
S105,基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
通过不断获取样本图像的类别预测结果和该样本图像的类别标签,可以确定模型的损失值,若基于损失值,识别出分类模型未收敛,则不断调整模型参数,并返回所述获取训练集中样本图像的步骤,直至模型收敛,其中,在模型收敛后,分类模型中的特征提取网络为训练完成的特征提取网络。示例性的,若计算得到的损失值大于预定阈值,则表明分类模型未收敛。
可以理解的是,这里模型参数包括特征提取网络的参数,和全连接层中的参数,在调整特征提取网络参数的同时,也可以调整全连接层中的参数,从而让全连接层有更好的分类能力。本公开并不对计算模型损失所利用的损失函数进行限定。另外,示例性的,在这里对模型的参数调整可以使用梯度下降法,当然并不局限于此。
本实施例中,将样本图像输入分类模型,得到模型输出结果和样本图像的目标图像特征;分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;识别目标图像特征和指定特征的相似性,得到识别结果;指定特征为属于辅助类别的图像的图像特征;再利用模型输出结果与识别结果,生成样本图像的类别预测结果,最后基于类别预测结果和样本图像的类别标签对分类模型的模型参数进行调整,直至模型收敛,得到训练完成的特征提取网络。可见,本方案中,全连接层只对部分类别建模,使得计算资源和存储资源的占用大大降低,并通过辅助类别来保证对样本图像类别的覆盖,以保证训练得到的特征提取网络的准确度。因此,通过本方案可以在降低资源占用的同时保证特征提取网络的准确度。
可选地,在本公开的另一实施例中,所述基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,包括步骤A1-A2:
步骤A1:将所述样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值;
为了保证类别预测结果与样本图像的类别标签的可比性,可以根据类别预测结果的形式,将样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值,以便后续的模型损失的计算。
示例性的,若类别预测结果为向量形式,且各个目标类别和辅助类别中的每一类别对应一维度;
所述将所述样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值,可以包括:
将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值;
其中,所述目标维数与所述类别预测结果的向量维数相同,且所述向量值的每一维度对应的类别与所述类别预测结果中相应维度相同。
例如,该分类模型的全连接层建模了1000个目标类别,辅助类别的数量是100,那么可以得到该类别真值对应的向量和类别预测结果的向量维数都是1100。且这两个向量的每一列所对应的类别相同,例如,第一列都用于表示对象A,第二列都用于表示对象B。
并且,需要说明的是,该类别真值能够表征出所述样本图像的类别标签所表示的类别,即通过该类别真值可以识别出样本图像对应的类别。
可选的,在一种实现方式中,所述将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值,包括步骤A11-A13:
步骤A11,构建具有目标维数的初始向量值;
即,构建与类别预测结果的向量维度相同的初始向量值。
步骤A12,若所述样本图像的类别标签所指示的类别属于任一目标类别,将所述初始向量值中的第一指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第一指定维度为各个目标类别所对应的维度中,所述类别标签所指示的类别;
也就是,该样本图像的类别被分类模型中的全连接层建模的各个目标类别所覆盖,那么设置该初始向量值中对应的维度位置的值为第一数值,其余维度位置的值设置为第二数值。
步骤A13,若所述类别标签所指示的类别不属于任一目标类别,将所述初始向量值中的第二指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第二指维度为所述辅助类别所对应的维度中,所述类别标签所指示的类别对应的维度。
也就是,该样本图像类别没有被分类模型中的全连接层建模的各个目标类别所覆盖,设置该初始向量值中,与辅助类别所对应的维度位置的值为第一数值,其余维度位置的值设置为第二数值。
所得到的预测结果的向量,也可以根据同样的方法设置。
在一种实现方式中,初始向量值和对应的预测结果的向量,都可以为one-hot向量,按照上述方式,将对应维度的位置设置为1,其余设置为0。
步骤A2:基于所述类别预测结果和所述类别真值的差异,对所述分类模型的模型参数进行调整。
其中,基于类别预测结果和类别真值的差异,可以计算分类模型的损失值,若基于损失值判定出分类模型未收敛,则对分类模型的模型参数进行调整。本公开并不对模型损失值计算时所利用的损失函数进行限定。
本实施例中,通过将样本图像的类别标签进行形式转换,得到与类别预测结果形式相同的类别真值;基于类别预测结果和所述类别真值的差异,对分类模型的模型参数进行调整。可见,通过本方案,保证了类别标签和的类别预测结果可比性。
可选地,在另一实施例中,所述识别所述目标图像特征和指定特征的相似性,得到识别结果,包括:
识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果;其中,所述动态队列中的特征为指定特征。
在这种实现方式中,将属于辅助类别的图像的图像特征构造为一个动态队列,此时,动态队列中的特征为指定特征,且动态队列中的特征的表征形式也可以是向量。通过识别目标图像特征和动态队列中的特征的相似性,得到识别结果。
其中,所述动态队列具有指定队列长度;
所述识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果之前,还可以包括步骤B1-B2:
步骤B1,获取所述样本图像对应的参考图像的图像特征;其中,所述参考图像与所述样本图像的类别相同;
为了获取样本图像对应的参考图像的图像特征,可以将该样本图像本身看作参考图像,并通过另一个特征提取网络提取图像特征;也可以将与该样本图像类别相同的另一张图像作为参考图像,输入同一特征提取网络或另一特征提取网络得到图像特征,这都是可以的。
步骤B2,利用所述参考图像的图像特征,更新所述动态队列。
动态队列是一个不断变化且长度固定的队列,此时将该参考图像的图像特征放入动态队列中,要从动态队列中选择一个特征移出队列,由此保证动态队列长度不变,如何选择移出的特征有多种可能的方式,如选择最早入列的特征或者样本数最多的特征,具体方式可以根据情况设定,本公开不做具体限定。
通过这种方法,就可以保证动态队列中始终有一个指定特征与样本图像的类别相匹配,即保证对训练集所有类别的覆盖。此时,识别结果表征了,该样本图像属于该动态队列中的哪一个类别。
本实施例中,将指定特征构建成一个动态队列,识别目标图像特征和动态队列中的特征的相似性,可以有效保证样本图像的类别始终被覆盖到,从而在降低资源占用的同时保证特征提取网络的准确度。另外,相对于上述通过动态队列进行网络训练的现有方案而言,本方案可以通过全连接层来建模多样性较高或者样本数较多的类别,加快了模型收敛,提升了模型效果。
为了方便理解,下面结合图2所示的原理图,针对一个类别的训练过程,对本公开实施例所提供的一种特征提取网络的训练方法进行介绍。
首先,获取同一个类别的2张图像,图像A和图像B,经过特征提取网络后得到图像A的特征x和图像B的特征x’。其中,x即为目标图像特征,图像A通过分类模型的特征提取网络可以得到特征x;这两张图片可以经过同一个特征提取网络或者不同的特征提取网络,来得到相应的图像特征。
然后,如图2所示,将特征x’放入动态队列,同时从动态队列中选择一个特征移出队列,由此保证动态队列长度不变。
并且,如图2所示,分类模型中的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别。
再次,将特征x通过全连接层,可以得到模型输出结果,即与全连接层中各个类别对应的特征的相似性;同时识别x和动态队列中的每一特征的相似性,可以得到识别结果;将模型输出结果和识别结果融合后,得到类别预测结果。这里的类别预测结果可以用one-hot向量来表征。
最后,基于类别预测结果与类别标签的one-hot向量,计算损失值;若基于损失值判定出分类模型未收敛,通过梯度更新该特征提取网络和全连接层的参数,返回重新获取两张图片的过程;直至模型收敛,得到训练完成的特征提取网络。
其中,这里x的类别标签,即图像A的类别标签,用one-hot向量表示,且维度数量与由全连接层所建模的目标类别以及动态队列所涉及的辅助类别的数量之和。并且,onehot分2种情况:当x的类别标签在全连接层所建模的各个目标类别中时,one-hot向量对应位置置为1,其他位置置为0;当x的类别标签不在部分全连接层中时,one-hot向量动态队列对应位置置为1,其他位置置为0。
可见,本实施例中,全连接层只对部分类别建模,并通过动态队列来保证对样本图像类别的覆盖,可以在降低资源占用的同时保证特征提取网络的准确度。
本公开还提供一种特征提取网络的训练装置,如图3所示,该装置包括:
获取模块310,用于获取训练集中的样本图像;
输入模块320,用于将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型包括特征提取网络和全连接层,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
识别模块330,用于识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
生成模块340,用于利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
调整模块350,用于基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
可选地,所述识别模块330具体用于:
识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果;其中,所述动态队列中的特征为指定特征。
可选地,所述动态队列具有指定队列长度;
所述装置还包括:
第二获取模块,用于在识别模块识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果之前,获取所述样本图像对应的参考图像的图像特征;其中,所述参考图像与所述样本图像的类别相同;
更新模块,用于利用所述参考图像的图像特征,更新所述动态队列。
可选地,所述模型输出结果表征所述目标图像特征与针对各目标类别的图像特征的相似性;
生成模块340,具体用于:
对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果。
可选地,所述模型输出结果和所述识别结果均为向量形式;
生成模块340,对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果,包括:
将所述模型输出结果和识别结果进行向量拼接,得到所述样本图像的类别预测结果。
可选地,调整模块350,包括:
转换子模块,用于将所述样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值;
调整子模块,用于基于所述类别预测结果和所述类别真值的差异,对所述分类模型的模型参数进行调整。
可选地,所述类别预测结果为向量形式,且各个目标类别和辅助类别中的每一类别对应一维度;
转换子模块,具体用于:
将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值;
其中,所述目标维数与所述类别预测结果的向量维数相同,且所述向量值的每一维度对应的类别与所述类别预测结果中相应维度相同。
可选地,所述转换子模块将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值,包括:
构建具有目标维数的初始向量值;
若所述样本图像的类别标签所指示的类别属于任一目标类别,将所述初始向量值中的第一指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第一指定维度为各个目标类别所对应的维度中,所述类别标签所指示的类别对应的维度;
若所述类别标签所指示的类别不属于任一目标类别,将所述初始向量值中的第二指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第二指维度为所述辅助类别所对应的维度中,所述类别标签所指示的类别对应的维度。
本公开的技术方案中,所涉及的用户个人信息的收集、存储、使用、加工、传输、提供和公开等处理,均符合相关法律法规的规定,且不违背公序良俗。
需要说明的是,本实施例中的训练集来自于公开数据集。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图4示出了可以用来实施本公开的实施例的示例电子设备400的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图4所示,设备400包括计算单元401,其可以根据存储在只读存储器(ROM)402中的计算机程序或者从存储单元408加载到随机访问存储器(RAM)403中的计算机程序,来执行各种适当的动作和处理。在RAM 403中,还可存储设备400操作所需的各种程序和数据。计算单元401、ROM 402以及RAM 403通过总线404彼此相连。输入/输出(I/O)接口405也连接至总线404。
设备400中的多个部件连接至I/O接口405,包括:输入单元406,例如键盘、鼠标等;输出单元407,例如各种类型的显示器、扬声器等;存储单元408,例如磁盘、光盘等;以及通信单元409,例如网卡、调制解调器、无线通信收发机等。通信单元409允许设备400通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元401可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元401的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元401执行上文所描述的特征提取网络的训练方法。例如,在一些实施例中,特征提取网络的训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元408。在一些实施例中,计算机程序的部分或者全部可以经由ROM 402和/或通信单元409而被载入和/或安装到设备400上。当计算机程序加载到RAM 403并由计算单元401执行时,可以执行上文描述的特征提取网络的训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元401可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行上述特征提取网络的训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。

Claims (19)

1.一种特征提取网络的训练方法,包括:
获取训练集中的样本图像;
将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型包括特征提取网络和全连接层,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
2.根据权利要求1所述的方法,其中,所述识别所述目标图像特征和指定特征的相似性,得到识别结果,包括:
识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果;其中,所述动态队列中的特征为指定特征。
3.根据权利要求2所述的方法,其中,所述动态队列具有指定队列长度;
所述识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果之前,还包括:
获取所述样本图像对应的参考图像的图像特征;其中,所述参考图像与所述样本图像的类别相同;
利用所述参考图像的图像特征,更新所述动态队列。
4.根据权利要求1-3任一项所述的方法,其中,所述模型输出结果表征所述目标图像特征与针对各目标类别的图像特征的相似性;
所述利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果,包括:
对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果。
5.根据权利要求4所述的方法,其中,所述模型输出结果和所述识别结果均为向量形式;
所述对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果,包括:
将所述模型输出结果和识别结果进行向量拼接,得到所述样本图像的类别预测结果。
6.根据权利要求1-3任一项所述的方法,其中,所述基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,包括:
将所述样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值;
基于所述类别预测结果和所述类别真值的差异,对所述分类模型的模型参数进行调整。
7.根据权利要求6所述的方法,其中,所述类别预测结果为向量形式,且各个目标类别和辅助类别中的每一类别对应一维度;
所述将所述样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值,包括:
将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值;
其中,所述目标维数与所述类别预测结果的向量维数相同,且所述向量值的每一维度对应的类别与所述类别预测结果中相应维度相同。
8.根据权利要求7所述的方法,其中,所述将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值,包括:
构建具有目标维数的初始向量值;
若所述样本图像的类别标签所指示的类别属于任一目标类别,将所述初始向量值中的第一指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第一指定维度为各个目标类别所对应的维度中,所述类别标签所指示的类别对应的维度;
若所述类别标签所指示的类别不属于任一目标类别,将所述初始向量值中的第二指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第二指维度为所述辅助类别所对应的维度中,所述类别标签所指示的类别对应的维度。
9.一种特征提取网络的训练装置,包括:
获取模块,用于获取训练集中的样本图像;
输入模块,用于将所述样本图像输入分类模型,得到模型输出结果和所述样本图像的目标图像特征;其中,所述分类模型包括特征提取网络和全连接层,所述分类模型的全连接层用于建模各个目标类别,各个目标类别为所述训练集所覆盖的部分类别;
识别模块,用于识别所述目标图像特征和指定特征的相似性,得到识别结果;其中,所述指定特征为属于辅助类别的图像的图像特征,各个目标类别和辅助类别中存在所述样本图像所属的类别;
生成模块,用于利用所述模型输出结果与识别结果,生成所述样本图像的类别预测结果;
调整模块,用于基于所述类别预测结果和所述样本图像的类别标签,对所述分类模型的模型参数进行调整,并返回所述获取训练集中样本图像的步骤,直至模型收敛,得到训练完成的特征提取网络。
10.根据权利要求9所述的装置,其中,所述识别模块具体用于:
识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果;其中,所述动态队列中的特征为指定特征。
11.根据权利要求9所述的装置,其中,所述动态队列具有指定队列长度;
所述装置还包括:
第二获取模块,用于在识别模块识别所述目标图像特征和动态队列中的特征的相似性,得到识别结果之前,获取所述样本图像对应的参考图像的图像特征;其中,所述参考图像与所述样本图像的类别相同;
更新模块,用于利用所述参考图像的图像特征,更新所述动态队列。
12.根据权利要求9-11任一项所述的装置,其中,所述模型输出结果表征所述目标图像特征与针对各目标类别的图像特征的相似性;
所述生成模块,具体用于:
对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果。
13.根据权利要求12所述的装置,其中,所述模型输出结果和所述识别结果均为向量形式;
所述生成模块,对所述模型输出结果和识别结果进行融合,得到所述样本图像的类别预测结果,包括:
将所述模型输出结果和识别结果进行向量拼接,得到所述样本图像的类别预测结果。
14.根据权利要求9-11任一项所述的装置,其中,所述调整模块,包括:
转换子模块,用于将所述样本图像的类别标签进行形式转换,得到与所述类别预测结果形式相同的类别真值;
调整子模块,用于基于所述类别预测结果和所述类别真值的差异,对所述分类模型的模型参数进行调整。
15.根据权利要求14所述的装置,其中,所述类别预测结果为向量形式,且各个目标类别和辅助类别中的每一类别对应一维度;
所述转换子模块,具体用于:
将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值;
其中,所述目标维数与所述类别预测结果的向量维数相同,且所述向量值的每一维度对应的类别与所述类别预测结果中相应维度相同。
16.根据权利要求15所述的装置,其中,所述转换子模块将所述样本图像的类别标签转换为具有目标维数的向量值,作为类别真值,包括:
构建具有目标维数的初始向量值;
若所述样本图像的类别标签所指示的类别属于任一目标类别,将所述初始向量值中的第一指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第一指定维度为各个目标类别所对应的维度中,所述类别标签所指示的类别对应的维度;
若所述类别标签所指示的类别不属于任一目标类别,将所述初始向量值中的第二指定维度设置为第一数值,其余维度设置为第二数值;其中,所述第二指维度为所述辅助类别所对应的维度中,所述类别标签所指示的类别对应的维度。
17.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-8中任一项所述的方法。
18.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-8中任一项所述的方法。
19.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-8中任一项所述的方法。
CN202210030447.8A 2022-01-12 2022-01-12 一种特征提取网络的训练方法、装置及电子设备 Pending CN114360027A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210030447.8A CN114360027A (zh) 2022-01-12 2022-01-12 一种特征提取网络的训练方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210030447.8A CN114360027A (zh) 2022-01-12 2022-01-12 一种特征提取网络的训练方法、装置及电子设备

Publications (1)

Publication Number Publication Date
CN114360027A true CN114360027A (zh) 2022-04-15

Family

ID=81109225

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210030447.8A Pending CN114360027A (zh) 2022-01-12 2022-01-12 一种特征提取网络的训练方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN114360027A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115147679A (zh) * 2022-06-30 2022-10-04 北京百度网讯科技有限公司 多模态图像识别方法和装置、模型训练方法和装置
CN115578584A (zh) * 2022-09-30 2023-01-06 北京百度网讯科技有限公司 图像处理方法、图像处理模型的构建和训练方法

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115147679A (zh) * 2022-06-30 2022-10-04 北京百度网讯科技有限公司 多模态图像识别方法和装置、模型训练方法和装置
CN115147679B (zh) * 2022-06-30 2023-11-14 北京百度网讯科技有限公司 多模态图像识别方法和装置、模型训练方法和装置
CN115578584A (zh) * 2022-09-30 2023-01-06 北京百度网讯科技有限公司 图像处理方法、图像处理模型的构建和训练方法
CN115578584B (zh) * 2022-09-30 2023-08-29 北京百度网讯科技有限公司 图像处理方法、图像处理模型的构建和训练方法

Similar Documents

Publication Publication Date Title
CN113657465B (zh) 预训练模型的生成方法、装置、电子设备和存储介质
CN108898186B (zh) 用于提取图像的方法和装置
CN112749344B (zh) 信息推荐方法、装置、电子设备、存储介质及程序产品
CN114360027A (zh) 一种特征提取网络的训练方法、装置及电子设备
CN114494784A (zh) 深度学习模型的训练方法、图像处理方法和对象识别方法
CN113627536B (zh) 模型训练、视频分类方法,装置,设备以及存储介质
CN113780098A (zh) 文字识别方法、装置、电子设备以及存储介质
CN113961765B (zh) 基于神经网络模型的搜索方法、装置、设备和介质
CN113516185B (zh) 模型训练的方法、装置、电子设备及存储介质
CN115081630A (zh) 多任务模型的训练方法、信息推荐方法、装置和设备
CN114511064A (zh) 神经网络模型的解释方法、装置、电子设备和存储介质
CN113408632A (zh) 提高图像分类准确性的方法、装置、电子设备及存储介质
CN114120416A (zh) 模型训练方法、装置、电子设备及介质
CN113032251A (zh) 应用程序服务质量的确定方法、设备和存储介质
CN113642495B (zh) 用于评价时序提名的模型的训练方法、设备、程序产品
CN117131197B (zh) 一种招标书的需求类别处理方法、装置、设备及存储介质
CN114693950B (zh) 一种图像特征提取网络的训练方法、装置及电子设备
US20220383626A1 (en) Image processing method, model training method, relevant devices and electronic device
CN115760864A (zh) 图像分割方法、装置、电子设备及存储介质
CN115660363A (zh) 对话处理方法、装置、电子设备和存储介质
CN117651167A (zh) 资源推荐方法、装置、设备以及存储介质
CN114912541A (zh) 分类方法、装置、电子设备和存储介质
CN117093498A (zh) 测试用例编排的方法及装置、电子设备和存储介质
CN115862046A (zh) 样本生成方法、文档理解模型的训练方法和文档理解方法
CN114882309A (zh) 目标再识别模型的训练与目标再识别方法、装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination