CN114548288A - 模型训练、图像识别方法和装置 - Google Patents
模型训练、图像识别方法和装置 Download PDFInfo
- Publication number
- CN114548288A CN114548288A CN202210171304.9A CN202210171304A CN114548288A CN 114548288 A CN114548288 A CN 114548288A CN 202210171304 A CN202210171304 A CN 202210171304A CN 114548288 A CN114548288 A CN 114548288A
- Authority
- CN
- China
- Prior art keywords
- image
- student network
- teacher
- loss function
- regressor
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Image Analysis (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请公开了模型训练方法和装置,涉及图像处理技术领域。该方法的一具体实施方式包括:响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络;基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络;基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;将第二师生网络中的学生网络确定为识别模型。该实施方式有效提升了训练得到的识别模型的准确性和鲁棒性。
Description
技术领域
本申请涉及计算机技术领域,具体涉及图像处理技术领域,尤其涉及一种模型训练、图像识别方法和装置。
背景技术
目前所设计的高精度模型严重依赖大量的标记数据,一旦缺少足够的训练数据,模型的准确性和鲁棒性都会受到极大限制。然而,实际场景中的数据标注并不是一件容易的事,不仅费时费力而且可能出现标注错误的情况。目前有不少合成数据集,它们具有大量容易获取且标注准确的图片,但现实世界的数据集(目标域)与合成数据集(源域)的图片在纹理和背景上有着巨大差异,这就使得直接使用合成数据集训练的模型在实际应用中泛化性很差。因此无监督领域自适应问题,即如何利用标注准确但存在数据分布差异的源域数据来指导模型在目标域上的训练,具有重要的实际价值。
目前,对抗训练是领域自适应任务中常见的解决方案之一。通过在特征空间上的不断对抗博弈,使得源域和目标域的数据分布差异尽可能小。这样基于源域数据训练的模型,就可以应用于目标域数据上。
发明内容
本申请实施例提供了一种模型训练方法、装置、设备以及存储介质。
根据第一方面,本申请实施例提供了一种模型训练方法,该方法包括:响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络;基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络;基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;将第二师生网络中的学生网络确定为识别模型。
根据第二方面,本申请实施例提供了一种图像识别方法,该方法包括:获取包含目标对象的待识别图像;将待识别图像输入识别模型,得到目标对象的关键点信息,其中,识别模型是如上述第一方面任一实现方式描述的方法得到的识别模型。
根据第三方面,本申请实施例提供了一种模型训练装置,该装置包括:获取数据模块,被配置成响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络;第一训练模块,被配置成基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络;第二训练模块,被配置成基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;模型确定模块,被配置成将第二师生网络中的学生网络确定为识别模型。
根据第四方面,本申请实施例提供了一种图像识别装置,该装置包括:获取图像模块,被配置成获取包含目标对象的待识别图像;识别图像模块,被配置成将待识别图像输入识别模型,得到目标对象的关键点信息,其中,识别模型是如上述第一方面任一实现方式描述的方法得到的识别模型。
根据第五方面,本申请实施例提供了一种电子设备,该电子设备包括一个或多个处理器;存储装置,其上存储有一个或多个程序,当一个或多个程序被该一个或多个处理器执行,使得一个或多个处理器实现如第一方面或第二方面中任一实现方式描述的方法。
根据第六方面,本申请实施例提供了一种计算机可读介质,其上存储有计算机程序,该程序被处理器执行时实现如第一方面或第二方面中任一实现方式描述的方法。
本申请通过响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络;基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络;基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;将第二师生网络中的学生网络确定为识别模型,即通过具有多分支的师生网络基于样本对集分别进行自我训练(self-training)和对抗训练,以缩小实际图像和合成图像的差异,提高伪标签的精度,进而提升训练得到的识别模型的准确性和鲁棒性。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其他特征将通过以下的说明书而变得容易理解。
附图说明
图1是本申请可以应用于其中的示例性系统架构图;
图2是根据本申请的模型训练方法的一个实施例的流程图;
图3是根据本申请的模型训练方法的一个应用场景的示意图;
图4是根据本申请的模型训练方法的又一个实施例的流程图;
图5是根据本申请的图像识别方法的又一个实施例的流程图;
图6是根据本申请的模型训练装置的一个实施例的示意图;
图7是根据本申请的图像识别装置的一个实施例的示意图;
图8是适于用来实现本申请实施例的服务器的计算机系统的结构示意图。
具体实施方式
以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本申请。
图1示出了可以应用本申请的模型训练方法的实施例的示例性系统架构100。
如图1所示,系统架构100可以包括终端设备101、102、103,网络104和服务器105。网络104用以在终端设备101、102、103和服务器105之间提供通信链路的介质。网络104可以包括各种连接类型,例如有线、无线通信链路或者光纤电缆等等。
终端设备101、102、103通过网络104与服务器105交互,以接收或发送消息等。终端设备101、102、103上可以安装有各种通讯客户端应用,例如,图像识别类应用、通讯类应用等。
终端设备101、102、103可以是硬件,也可以是软件。当终端设备101、102、103为硬件时,可以是具有显示屏的各种电子设备,包括但不限于手机和笔记本电脑。当终端设备101、102、103为软件时,可以安装在上述所列举的电子设备中。其可以实现成多个软件或软件模块(例如用来提供模型训练的服务),也可以实现成单个软件或软件模块。在此不做具体限定。
服务器105可以是提供各种服务的服务器,例如,响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络;基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络;基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;将第二师生网络中的学生网络确定为识别模型。
需要说明的是,服务器105可以是硬件,也可以是软件。当服务器105为硬件时,可以实现成多个服务器组成的分布式服务器集群,也可以实现成单个服务器。当服务器为软件时,可以实现成多个软件或软件模块(例如用来提供训练模型的服务),也可以实现成单个软件或软件模块。在此不做具体限定。
需要指出的是,本公开的实施例所提供的模型训练方法可以由服务器105执行,也可以由终端设备101、102、103执行,还可以由服务器105和终端设备101、102、103彼此配合执行。相应地,模型训练的装置包括的各个部分(例如各个单元、子单元、模块、子模块)可以全部设置于服务器105中,也可以全部设置于终端设备101、102、103中,还可以分别设置于服务器105和终端设备101、102、103中。
应该理解,图1中的终端设备、网络和服务器的数目仅仅是示意性的。根据实现需要,可以具有任意数目的终端设备、网络和服务器。
图2示出了可以应用于本申请的模型训练方法的实施例的流程示意图200。在本实施例中,模型训练方法包括以下步骤:
步骤201,响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络。
在本实施例中,执行主体(如图1中所示的服务器105或终端设备101、102、103)可以从本地或远端的存储有样本对图像的服务器获取样本对集,并在获取到样本对集后,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络。
其中,样本对包括包含相同目标对象的第一图像和第二图像,第一图像为合成图像,第二图像为实际采集图像,即通过图像采集设备对目标对象进行图像采集得到的图像,第一图像已标注标签,标签用于指示目标对象的关键点信息,第二图像未标注标签。
这里,第一图像和第二图像所包括的相同的目标对象可以是任意的目标对象,例如,手、人脸等,本申请对此不作限定。
需要指出的是,若第一图像和第二图像所包括的目标对象均为手,则第一图像和第二图像的手势可以相同,也可以不同,本申请对此不作限定。
其中,师生网络,即教师-学生网络,属于迁移学习的一种,对于教师-学生网络,教师网络往往是一个更加复杂的网络,具有非常好的性能和泛化能力,可以用这个网络来作为一个soft target来指导另外一个更加简单的学生网络来学习,使得更加简单、参数运算量更少的学生网络也能够具有和教师网络相近的性能。
这里,学生网络和教师网络结构相同,学生网络包括特征提取器、主回归器和对抗回归器。
步骤202,基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,所述第二图像的伪标签由教师网络提供。
在本实施例中,执行主体可以根据学生网络中的主回归器针对第一图像的输出结果、针对第二图像的输出结果,即预测值,对抗回归器针对第一图像的输出结果,即预测值,第一图像的标签、第二图像的伪标签,构建目标损失函数,并最小化目标损失函数以对初始师生网络进行训练,得到第一师生网络。
其中,第二图像的伪标签由预训练的教师网络提供。
这里,执行主体可以根据第一损失函数、第二损失函数和第三损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。其中,第一损失函数可以基于主回归器针对第一图像的预测值及第一图像的标签确定,第二损失函数可以基于主回归器针对第二图像的预测值及第二图像的伪标签确定,第三损失函数可以基于对抗回归器针对第一图像的预测值与第一图像的标签确定。
需要指出的是,若学生网络还包括输出回归器,则第三损失函数也可以基于输出回归器针对第一图像的预测值及第一图像的标签确定,本申请对此不作限定。
进一步地,目标损失函数还可以包括正则化损失函数、全局损失函数等可进一步提升模型性能的损失函数。
此外,需要说明的是,训练过程中,学生网络的模型参数通过正常的SGD(Stochastic Gradient Descent,随机梯度下降)算法进行更新,而教师网络并不参与梯度反向传播。教师网络的模型参数θ′通过对学生网络的模型参数θ的EMAN(ExponentialMoving Average Normalization,指数移动平均归一化)来更新,具体如下式所示:
θ′=mθ′+(1-m)θ
μ′=mμ′+(1-m)μ
σ′2=mσ′2+(1-m)σ2
其中,μ,σ2分别是BN的均值和方差。
步骤203,基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络。
在本实施例中,执行主体可以首先基于样本对集中的第二图像,保持学生网络的其余参数不变,对第一师生网络中的学生网络的对抗回归器进行训练,得到初始第二师生网络,再基于样本对集中的第二图像,保持学生网络的其余参数不变,对初始第二师生网络中的特征提取器进行训练,得到第二师生网络;也可以首先基于样本对集中的第二图像,保持学生网络的其余参数不变,对第一师生网络中的学生网络的特征提取器进行训练,得到初始第二师生网络,再基于样本对集中的第二图像,保持学生网络的其余参数不变,对初始第二师生网络中的对抗回归器进行训练,得到第二师生网络,本申请对此不作限定。
步骤204,将第二师生网络中的学生网络确定为识别模型。
在本实施例中,执行主体在得到第二师生网络后,可将第二师生网络中的学生网络确定为识别模型。
其中,识别模型可用于包含目标对象的图像的识别。
继续参见图3,图3是根据本实施例的模型训练的方法的应用场景的一个示意图。
在图3的应用场景中,执行主体301响应于获取到样本对集302,对于每一样本对,将样本对输入初始师生网络303中的学生网络和教师网络,其中,样本对包括包含相同目标对象的第一图像和第二图像,例如,第一图像和第二图像均为手势图像,且二者手势相同,第一图像为合成图像,第二图像为实际采集图像,第一图像已标注标签,标签用于指示目标对象的关键点信息,第二图像未标注标签,学生网络包括特征提取器、主回归器和对抗回归器;基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络303进行训练,得到第一师生网络304;基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络305;将第二师生网络中的学生网络确定为识别模型306。
本公开的模型训练的方法,通过响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络;基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络;基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;将第二师生网络中的学生网络确定为识别模型,有效提升了训练得到的识别模型的准确性和鲁棒性。
进一步参考图4,其示出了图2所示的模型训练方法的又一个实施例的流程400。在本实施例中,流程400可包括以下步骤:
步骤401,响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络。
在本实施例中,步骤401的实现细节和技术效果,可以参考对步骤201的描述,在此不再赘述。
步骤402,基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,输出回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实施例中,学生网络还包括输出回归器,执行主体可以根据学生网络中的主回归器针对第一图像的输出结果、针对第二图像的输出结果,即预测值,对抗回归器针对第一图像的输出结果,即预测值,输出回归器针对第一图像的输出结果,即预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数,并最小化目标损失函数以对初始师生网络进行训练,得到第一师生网络。
这里,执行主体可以根据第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。其中,第一损失函数可以基于主回归器针对第一图像的预测值及第一图像的标签确定,第二损失函数可以基于主回归器针对第二图像的预测值及第二图像的伪标签确定,第三损失函数可以基于对抗回归器针对第一图像的预测值与第一图像的标签确定,第四损失函数可以基于输出回归器针对第一图像的预测值与第一图像的标签确定。
在一些可选的方式中,基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,输出回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实现方式中,执行主体可以根据第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。其中,第一损失函数可以基于主回归器针对第一图像的预测值及第一图像的标签确定,第二损失函数可以基于主回归器针对第二图像的预测值及第二图像的伪标签确定,第三损失函数可以基于输出回归器针对第一图像的预测值及对抗回归器针对第一图像的预测值确定,第四损失函数可以基于输出回归器针对第一图像的预测值与第一图像的标签确定。
具体地,对于样本对集,可将样本对集中的第一图像构成的集合作为源域样本标签为将样本对集中的第二图像构成的集合作为目标域样本对于每一样本对,将其输入初始师生网络中的学生网络,可得到相对应的特征图(Fs,Ft)和热图(Hs,Ht),具体如下式所示:
Fs=ψ(xs),Hs=f(Fs)
Ft=ψ(xt),Ht=f(Ft)
其中,ψ是特征提取器,f是回归器。
目标损失函数可通过下式表示:
其中,LT(f0(ψ(xs)),ys)表征第一损失函数,表征第二损失函数,LT(f1(ψ(xs)),f2(ψ(xs)))表征第三损失函数,LT(f1(ψ(xs)),ys)表征第四损失函数;λ1,λ2表征损失的权重;f0指示主回归器、f1指示输出回归器、f2指示对抗回归器。
这里,为了后续的对抗训练,可以采用KL散度来计算热图损失。首先定义空间概率分布PT(Hk),k∈{1,2,…,K},它对空间维度上每个关键点k的热图Hk∈RH×W进行归一化:
用σ表示空间softmax函数:
然后使用KL散度来计算loss值:
其中,Hs=f(ψ(xs))∈RK×W×H,是标签ys中每个关键点k的热图。由于使用KL(Kullback–Leibler divergence,KL散度)散度不会引起数值爆炸,在后续的计算中,默认使用KL代替MSE(Mean Squared Error,均方误差)。
源域中预测值与标签之间的损失为:
Ls=LT(Hs,ys)
该实现方式通过基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,进而基于第一师生网络确定识别模型,进一步提升了确定出的识别模型的鲁棒性。
在一些可选的方式中,基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络,包括:保持学生网络中的特征提取器、主回归器和输出回归器的参数不变,基于样本对集中的第二图像,最小化第五损失函数对第一师生网路进行训练,得到初始第二师生网络;保持学生网络中的各回归器的参数不变,基于样本对集中的第二图像,最小化第六损失函数对初始第二师生网络进行训练,得到第二师生网络。
在本实现方式中,执行主体可以首先保持学生网络中的特征提取器、主回归器和输出回归器的参数不变,基于样本对集中的第二图像,最小化第五损失函数对第一师生网络进行训练,得到初始第二师生网络。其中,第五损失函数基于第一师生网络针对第二图像的预测值和伪标签确定。
具体可通过下式表示:
进一步地,执行主体保持各回归器的参数不变,基于样本对集中的第二图像,最小化第六损失函数对初始第二师生网络进行训练,得到第二师生网络,其中,第六损失函数基于输出回归器针对第二图像中的预测值和对抗回归器针对第二图像的预测值确定。
具体可通过下式表示:
其中,输出回归器和对抗回归器,用来实现对抗训练,f1是输出回归器,f2是对抗回归器。当师生网络在目标域上预测出错时,预测错误的位置并不是在像素空间均匀分布的。例如,当模型预测手部姿势时,关键点更有可能位于手部区域,出现在背景中的概率接近于零。因此,使用错误概率分布,来使对抗回归器更关注那些概率高的位置。简单的说,就是利用输出空间概率上的稀疏性来引导对抗回归器的优化,让回归器更关注那些出现概率大的位置。
然后我们可以获得:
为了缓解了对抗训练的优化困难,将对抗训练中同一目标的极大极小化转换成两个相反目标的最小化。这两个相反的目标是分别为特征提取器和对抗回归器设计的。对抗回归器的目标是尽量减少对抗回归器f2的预测值和ground false prediction之间的损失。特征提取的目标是最小化f1和f2之间的损失。计算公式为:
该实现方式通过基于样本对集中的第二图像,最小化第五损失函数对学生网络中的对抗回归器进行训练,得到初始第二师生网络;基于样本对集中的第二图像,最小化第六损失函数对第二师生网络中的特征提取器进行训练,得到第二师生网络,根据基于第二师生网络确定识别模型,进一步提升了伪标签的精度,进而进一步提升了模型的鲁棒性。
在一些可选的方式中,第五损失函数通过以下方式确定:基于第一师生网络中教师网络针对第二图像的伪标签与学生网络中输出回归器针对第二图像的预测值的和,得到初始数值;基于初始数值与学生网络中对抗回归器针对第二图像的预测值的差,确定第五损失函数。
在本实现方式中,执行主体可以根据第一师生网络中教师网络针对第二图像的伪标签与第一师生网络中学生网络中输出回归器针对第二图像的预测值的和,得到初始数值;基于初始数值与学生网络中对抗回归器针对第二图像的预测值的差,确定第五损失函数。
该实现方式通过基于第一师生网络中教师网络针对第二图像的伪标签与学生网络中输出回归器针对第二图像的预测值的和,得到初始数值;基于初始数值与学生网络中对抗回归器针对第二图像的预测值的差,确定第五损失函数,有助于提升确定出的第五损失函数的准确性。
在一些可选的方式中,基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和全局损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实现方式中,执行主体可以根据第一损失函数、第二损失函数、第三损失函数、第四损失函数和全局损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,其中,全局损失函数基于样本对集中的第一图像和第二图像的特征分布确定。
具体地,目标损失函数可通过下式表示:
其中,λ1,λ2,λ3表征损失的权重。
该实现方式通过基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和全局损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,进而基于第一师生网络确定识别模型,利用全局损失函数在一定程度上弥补了目标域和源域间隙,并减轻了噪声对伪标签的影响,进一步提升了确定出的识别模型的鲁棒性。
在一些可选的方式中,基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:基于第一损失函数、第二损失函数、第三损失函数、第四损失函数、全局损失函数和正则化损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实现方式中,执行主体可以根据第一损失函数、第二损失函数、第三损失函数、第四损失函数、全局损失函数和正则化损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
具体地,目标损失函数可通过下式表示:
其中,λ1,λ2,λ3,λ4表征损失的权重,H(xt;w)是主回归器的输出。
该实现方式通过基于第一损失函数、第二损失函数、第三损失函数、第四损失函数、全局损失函数和正则化损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,进而基于第一师生网络确定识别模型,利用正则化损失函数进一步减轻了噪声对伪标签的影响,同时加快了收敛速度,进一步提升了模型的鲁棒性,同时提高了模型训练效率。
步骤403,基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络。
在本实施例中,步骤403的实现细节和技术效果,可以参考对步骤203的描述,在此不再赘述。
步骤404,将第二师生网络中的学生网络确定为识别模型。
在本实施例中,步骤404的实现细节和技术效果,可以参考对步骤204的描述,在此不再赘述。
本申请的上述实施例,与图2对应的实施例相比,本实施例中的模型训练方法的流程400提现了基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,输出回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,进而基于第一师生网络,确定识别模型,有助于进一步提升得到的识别模型的准确性和鲁棒性。
继续参考图5,示出了根据本申请的图像识别方法的一个实施例的流程500。该图像识别方法,包括以下步骤:
步骤501,获取包含目标对象的待识别图像。
在本实施例中,执行主体可通过有线或无线连接方式获取包含目标对象的待识别对象。
其中,无线连接方式可以包括但不限于3G/4G连接、WiFi连接、蓝牙连接、WiMAX连接、Zigbee连接、UWB(ultra wideband)连接、以及其他现在已知或将来开发的无线连接方式。
步骤502,将待识别图像输入识别模型,得到目标对象的关键点信息。
在本实施例中,执行主体在获取到待识别图像后,可将待识别图像输入识别模型,以得到与待识别图像对应的目标对象的关键点信息。其中,识别模型是如图2对应的实施例描述的方法得到的识别模型,这里不再赘述。
本公开实施例提供的图像识别方法,通过获取待识别图像;将待识别图像输入识别模型,得到待识别图像对应的目标对象的关键点信息,其中,识别模型是如图2实施例描述的方法得到的识别模型,有助于提升对待识别图像进行识别的准确性。
进一步参考图6,作为对上述各图所示方法的实现,本申请提供了一种模型训练装置的一个实施例,该装置实施例与图2所示的方法实施例相对应,该装置具体可以应用于各种电子设备中。
如图6所示,本实施例的模型训练装置600包括:获取数据模块601、第一训练模块602、第二训练模块603和模型确定模块604。
其中,获取数据模块601,可被配置成响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络。
第一训练模块602,可被配置成基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
第二训练模块603,可被配置成基于样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络。
模型确定模块604,可被配置成将第二师生网络中的学生网络确定为识别模型。
在本实施例的一些可选的方式中,第一训练模块进一步被配置成:基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,输出回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实施例的一些可选的方式中,第一训练模块进一步被配置成:基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实施例的一些可选的方式中,第一训练模块进一步被配置成:基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和全局损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实施例的一些可选的方式中,第一训练模块进一步被配置成:基于第一损失函数、第二损失函数、第三损失函数、第四损失函数、全局损失函数和正则化损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
在本实施例的一些可选的方式中,第二训练模块进一步被配置成:保持学生网络中的特征提取器、主回归器和输出回归器的参数不变,基于样本对集中的第二图像,最小化第五损失函数对第一师生网络进行训练,得到初始第二师生网络;保持学生网络中的各回归器的参数不变,基于样本对集中的第二图像,最小化第六损失函数对初始第二师生网络进行训练,得到第二师生网络。
在本实施例的一些可选的方式中,第五损失函数通过以下方式确定:基于第一师生网络中教师网络针对第二图像的伪标签与学生网络中输出回归器针对第二图像的预测值的和,得到初始数值;基于初始数值与学生网络中对抗回归器针对第二图像的预测值的差,确定第五损失函数。
进一步参考图7,作为对上述各图所示方法的实现,本公开提供了一种图像识别装置的一个实施例,该装置实施例与图5所示的方法实施例相对应,该装置具体可以应用于各种电子设备中。
如图7所示,本实施例的图像识别装置700包括:获取图像模块701和识别图像模块702。
其中,获取图像模块701,可被配置成获取包含目标对象的待识别图像。
识别图像模块702,可被配置成将待识别图像输入识别模型,得到目标对象的关键点信息。
根据本申请的实施例,本申请还提供了一种电子设备和一种可读存储介质。
如图8所示,是根据本申请实施例的模型训练的方法的电子设备的框图。
800是根据本申请实施例的模型训练的方法的电子设备的框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本申请的实现。
如图8所示,该电子设备包括:一个或多个处理器801、存储器802,以及用于连接各部件的接口,包括高速接口和低速接口。各个部件利用不同的总线互相连接,并且可以被安装在公共主板上或者根据需要以其它方式安装。处理器可以对在电子设备内执行的指令进行处理,包括存储在存储器中或者存储器上以在外部输入/输出装置(诸如,耦合至接口的显示设备)上显示GUI的图形信息的指令。在其它实施方式中,若需要,可以将多个处理器和/或多条总线与多个存储器和多个存储器一起使用。同样,可以连接多个电子设备,各个设备提供部分必要的操作(例如,作为服务器阵列、一组刀片式服务器、或者多处理器系统)。图8中以一个处理器801为例。
存储器802即为本申请所提供的非瞬时计算机可读存储介质。其中,所述存储器存储有可由至少一个处理器执行的指令,以使所述至少一个处理器执行本申请所提供的模型训练的方法。本申请的非瞬时计算机可读存储介质存储计算机指令,该计算机指令用于使计算机执行本申请所提供的模型训练的方法。
存储器802作为一种非瞬时计算机可读存储介质,可用于存储非瞬时软件程序、非瞬时计算机可执行程序以及模块,如本申请实施例中的模型训练的方法对应的程序指令/模块(例如,附图6所示的获取数据模块601、第一训练模块602、第二训练模块603和模型确定模块604)。处理器801通过运行存储在存储器802中的非瞬时软件程序、指令以及模块,从而执行服务器的各种功能应用以及数据处理,即实现上述方法实施例中的模型训练的方法。
存储器802可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序;存储数据区可存储模型训练的的电子设备的使用所创建的数据等。此外,存储器802可以包括高速随机存取存储器,还可以包括非瞬时存储器,例如至少一个磁盘存储器件、闪存器件、或其他非瞬时固态存储器件。在一些实施例中,存储器802可选包括相对于处理器801远程设置的存储器,这些远程存储器可以通过网络连接至模型训练的的电子设备。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
模型训练的方法的电子设备还可以包括:输入装置803和输出装置804。处理器801、存储器802、输入装置803和输出装置804可以通过总线或者其他方式连接,图8中以通过总线连接为例。
输入装置803可接收输入的数字或字符信息,例如触摸屏、小键盘、鼠标、轨迹板、触摸板、指示杆、一个或者多个鼠标按钮、轨迹球、操纵杆等输入装置。输出装置804可以包括显示设备、辅助照明装置(例如,LED)和触觉反馈装置(例如,振动电机)等。该显示设备可以包括但不限于,液晶显示器(LCD)、发光二极管(LED)显示器和等离子体显示器。在一些实施方式中,显示设备可以是触摸屏。
此处描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、专用ASIC(专用集成电路)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
这些计算程序(也称作程序、软件、软件应用、或者代码)包括可编程处理器的机器指令,并且可以利用高级过程和/或面向对象的编程语言、和/或汇编/机器语言来实施这些计算程序。如本文使用的,术语“机器可读介质”和“计算机可读介质”指的是用于将机器指令和/或数据提供给可编程处理器的任何计算机程序产品、设备、和/或装置(例如,磁盘、光盘、存储器、可编程逻辑装置(PLD)),包括,接收作为机器可读信号的机器指令的机器可读介质。术语“机器可读信号”指的是用于将机器指令和/或数据提供给可编程处理器的任何信号。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
根据本申请实施例的技术方案,有效提升了训练得到的识别模型准确性和鲁棒性。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发申请中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本申请公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本申请保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本申请的精神和原则之内所作的修改、等同替换和改进等,均应包含在本申请保护范围之内。
Claims (12)
1.一种模型训练方法,所述方法包括:
响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络,其中,所述样本对包括包含相同目标对象的第一图像和第二图像,所述第一图像为合成图像,所述第二图像为实际采集图像,所述第一图像已标注标签,所述标签用于指示目标对象的关键点信息,所述第二图像未标注标签,学生网络包括特征提取器、主回归器和对抗回归器;
基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,所述第二图像的伪标签由教师网络提供;
基于所述样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;
将所述第二师生网络中的学生网络确定为识别模型。
2.根据权利要求1所述的方法,其中,所述学生网络还包括:输出回归器,以及所述基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:
基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,输出回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
3.根据权利要求2所述的方法,其中,所述基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,输出回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:
基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。其中,所述第一损失函数基于主回归器针对第一图像的预测值及第一图像的标签确定,所述第二损失函数基于主回归器针对第二图像的预测值及第二图像的伪标签确定,所述第三损失函数基于输出回归器针对第一图像的预测值及对抗回归器针对第一图像的预测值确定,所述第四损失函数基于输出回归器针对第一图像的预测值及第一图像的标签确定。
4.根据权利要求2所述的方法,其中,所述基于第一损失函数、第二损失函数、第三损失函数和第四损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:
基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和全局损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,其中,所述全局损失函数基于样本对集中的第一图像和第二图像的特征分布确定。
5.根据权利要求4所述的方法,其中,基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和全局损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,包括:
基于第一损失函数、第二损失函数、第三损失函数、第四损失函数、全局损失函数和正则化损失函数构建目标损失函数以对初始师生网络进行训练,得到第一师生网络。
6.根据权利要求2所述的方法,其中,所述基于所述样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络,包括:
保持学生网络中的特征提取器、主回归器和输出回归器的参数不变,基于样本对集中的第二图像,最小化第五损失函数对第一师生网络进行训练,得到初始第二师生网络,其中,所述第五损失函数基于第一师生网络针对第二图像的预测值与伪标签确定;
保持学生网络中的各回归器的参数不变,基于样本对集中的第二图像,最小化第六损失函数对初始第二师生网络进行训练,得到第二师生网络,其中,第六损失函数基于输出回归器针对第二图像中的预测值和对抗回归器针对第二图像的预测值确定。
7.根据权利要求6所述的方法,其中,所述第五损失函数通过以下方式确定:
基于第一师生网络中教师网络针对第二图像的伪标签与学生网络中输出回归器针对第二图像的预测值的和,得到初始数值;
基于所述初始数值与学生网络中对抗回归器针对第二图像的预测值的差,确定第五损失函数。
8.一种识别方法,所述方法包括:
获取包含目标对象的待识别图像;
将所述待识别图像输入识别模型,得到所述目标对象的关键点信息,其中,所述识别模型是如权利要求1-7之一所述的方法得到的识别模型。
9.一种模型训练装置,包括:
获取数据模块,被配置成响应于获取到样本对集,对于每一样本对,将样本对输入初始师生网络中的学生网络和教师网络,其中,所述样本对包括包含相同目标对象的第一图像和第二图像,所述第一图像为合成图像,所述第二图像为实际采集图像,所述第一图像已标注标签,所述标签用于指示目标对象的关键点信息,所述第二图像未标注标签,学生网络包括特征提取器、主回归器和对抗回归器;
第一训练模块,被配置成基于学生网络中的主回归器针对第一图像的预测值、针对第二图像的预测值,对抗回归器针对第一图像的预测值,以及第一图像的标签和第二图像的伪标签,构建目标损失函数以对初始师生网络进行训练,得到第一师生网络,所述第二图像的伪标签由教师网络提供;
第二训练模块,被配置成基于所述样本对集中的第二图像,分别对第一师生网络中的学生网络的特征提取器和对抗回归器进行训练,得到第二师生网络;
模型确定模块,被配置成将所述第二师生网络中的学生网络确定为识别模型。
10.一种图像识别装置,包括:
获取图像模块,被配置成获取包含目标对象的待识别图像;
识别图像模块,被配置成将所述待识别图像输入识别模型,得到所述目标对象的关键点信息,其中,所述识别模型是如权利要求1-7之一所述的方法得到的识别模型。
11.一种电子设备,其特征在于,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-8中任一项所述的方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,其特征在于,所述计算机指令用于使所述计算机执行权利要求1-8中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210171304.9A CN114548288A (zh) | 2022-02-24 | 2022-02-24 | 模型训练、图像识别方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210171304.9A CN114548288A (zh) | 2022-02-24 | 2022-02-24 | 模型训练、图像识别方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114548288A true CN114548288A (zh) | 2022-05-27 |
Family
ID=81677623
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210171304.9A Pending CN114548288A (zh) | 2022-02-24 | 2022-02-24 | 模型训练、图像识别方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114548288A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116051926A (zh) * | 2023-01-12 | 2023-05-02 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、图像识别方法和装置 |
-
2022
- 2022-02-24 CN CN202210171304.9A patent/CN114548288A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116051926A (zh) * | 2023-01-12 | 2023-05-02 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、图像识别方法和装置 |
CN116051926B (zh) * | 2023-01-12 | 2024-04-16 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、图像识别方法和装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109145781B (zh) | 用于处理图像的方法和装置 | |
CN111639710A (zh) | 图像识别模型训练方法、装置、设备以及存储介质 | |
US11741355B2 (en) | Training of student neural network with teacher neural networks | |
CN111539514A (zh) | 用于生成神经网络的结构的方法和装置 | |
JP2022058915A (ja) | 画像認識モデルをトレーニングするための方法および装置、画像を認識するための方法および装置、電子機器、記憶媒体、並びにコンピュータプログラム | |
CN110520871A (zh) | 训练机器学习模型 | |
JP7262571B2 (ja) | 知識グラフのベクトル表現生成方法、装置及び電子機器 | |
CN112001180A (zh) | 多模态预训练模型获取方法、装置、电子设备及存储介质 | |
CN111767359B (zh) | 兴趣点分类方法、装置、设备以及存储介质 | |
CN111079945B (zh) | 端到端模型的训练方法及装置 | |
CN111950291A (zh) | 语义表示模型的生成方法、装置、电子设备及存储介质 | |
CN110543558B (zh) | 问题匹配方法、装置、设备和介质 | |
CN111708876A (zh) | 生成信息的方法和装置 | |
CN111259671A (zh) | 文本实体的语义描述处理方法、装置及设备 | |
CN112580733B (zh) | 分类模型的训练方法、装置、设备以及存储介质 | |
CN111931067A (zh) | 兴趣点推荐方法、装置、设备和介质 | |
CN112541362B (zh) | 一种泛化处理的方法、装置、设备和计算机存储介质 | |
CN112507090A (zh) | 用于输出信息的方法、装置、设备和存储介质 | |
CN113537374A (zh) | 一种对抗样本生成方法 | |
CN111767833A (zh) | 模型生成方法、装置、电子设备及存储介质 | |
CN111695698A (zh) | 用于模型蒸馏的方法、装置、电子设备及可读存储介质 | |
CN109034199B (zh) | 数据处理方法及装置、存储介质和电子设备 | |
CN111966782B (zh) | 多轮对话的检索方法、装置、存储介质及电子设备 | |
CN111241838B (zh) | 文本实体的语义关系处理方法、装置及设备 | |
CN112380855A (zh) | 确定语句通顺度的方法、确定概率预测模型的方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |