CN111680631B - 模型训练方法及装置 - Google Patents

模型训练方法及装置 Download PDF

Info

Publication number
CN111680631B
CN111680631B CN202010520344.0A CN202010520344A CN111680631B CN 111680631 B CN111680631 B CN 111680631B CN 202010520344 A CN202010520344 A CN 202010520344A CN 111680631 B CN111680631 B CN 111680631B
Authority
CN
China
Prior art keywords
similarity
loss
max
training sample
auxiliary
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010520344.0A
Other languages
English (en)
Other versions
CN111680631A (zh
Inventor
熊凯
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Guangzhou Shiyuan Electronics Thecnology Co Ltd
Original Assignee
Guangzhou Shiyuan Electronics Thecnology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Guangzhou Shiyuan Electronics Thecnology Co Ltd filed Critical Guangzhou Shiyuan Electronics Thecnology Co Ltd
Priority to CN202010520344.0A priority Critical patent/CN111680631B/zh
Publication of CN111680631A publication Critical patent/CN111680631A/zh
Application granted granted Critical
Publication of CN111680631B publication Critical patent/CN111680631B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • G06V40/16Human faces, e.g. facial parts, sketches or expressions
    • G06V40/161Detection; Localisation; Normalisation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • G06V40/16Human faces, e.g. facial parts, sketches or expressions
    • G06V40/168Feature extraction; Face representation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • G06V40/16Human faces, e.g. facial parts, sketches or expressions
    • G06V40/172Classification, e.g. identification
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • Oral & Maxillofacial Surgery (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Biophysics (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • Multimedia (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Human Computer Interaction (AREA)
  • Probability & Statistics with Applications (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种模型训练方法及装置。其中,该方法包括:构造识别模型的损失函数,其中,损失函数包括辅助损失部分,辅助损失部分用于要求第一相似度大于第一相似度阈值,第一相似度为训练样本与该训练样本的目标原型之间的相似度;根据损失函数,训练识别模型。本发明解决了在相关技术中,在识别模型的训练与应用之间存在隔阂的技术问题。

Description

模型训练方法及装置
技术领域
本发明涉及人工智能领域,具体而言,涉及一种模型训练方法及装置。
背景技术
基于深度神经网络的模型训练目前主要有两种方式,以人脸识别为例:一种是基于多分类进行训练,即学习将一张人脸图像,分配到某个类别中,这里类别数量等于训练数据集中的人数;另一种是基于度量学习,比如学习判定两张人脸图像是否为同一人(也有基于三张或多张人脸图像进行学习的)。以上两种训练过程,就是模型在学习如何抽取人脸的特征。在理想情况下,同一个人的不同照片的特征,相互距离很近,而不同人的照片的特征,相互距离很远。
人脸识别是计算机视觉的一个研究方向,目前已有大量落地应用,比如人脸签到、人脸支付和门禁系统等。人脸识别模型,就是对每一张查询人脸图像,计算出对应的特征向量,然后将其与预先计算并存储的人脸库的特征一一比对,找到最近的人脸ID及其距离,若这个最近的距离,小于事先设定好的距离阈值,则识别成功,认为该查询人脸的身份就是返回的ID。但基于多分类的方式中,人脸识别模型的训练过程和人脸识别模型的应用之间存在隔阂,训练模型时,并不涉及到距离阈值的概念,而实际应用时,需要基于交叉验证的方式,确定合适的距离阈值。
因此,在相关技术中,在识别模型的训练与应用之间存在隔阂的问题。针对上述的问题,目前尚未提出有效的解决方案。
发明内容
本发明实施例提供了一种模型训练方法及装置,以至少解决在相关技术中,在识别模型的训练与应用之间存在隔阂的技术问题。
根据本发明实施例的一个方面,提供了一种模型训练方法,包括:构造识别模型的损失函数,其中,所述损失函数包括辅助损失部分,所述辅助损失部分用于要求第一相似度大于第一相似度阈值,所述第一相似度为训练样本与该训练样本的目标原型之间的相似度;根据所述损失函数,训练所述识别模型。
可选地,该方法还包括:通过以下方式,构造所述辅助损失部分:确定第二相似度,其中,所述第二相似度为与所述训练样本相似度最大的非目标原型与所述训练样本之间的相似度;根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分。
可选地,根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分,包括:通过以下公式,构造所述辅助损失部分:Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ),其中,Laux为所述辅助损失部分的损失值,yi为所述训练样本所属的类别,xi为所述训练样本采用所述识别模型提取的特征,Wyi为所述目标原型的特征,Wyi*xi为所述第一相似度,Wmax_j为与所述训练样本相似度最大的非目标原型的特征,Wmax_j*xi为所述第二相似度,σ为所述第一相似度阈值。
可选地,根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分,包括:确定第二相似度阈值,其中,所述第一相似度阈值比所述第二相似度阈值至少大相似度间隔;根据所述第一相似度,所述第二相似度,所述第一相似度阈值,以及所述第二相似度阈值,构造所述辅助损失部分。
可选地,根据所述第一相似度,所述第二相似度,所述第一相似度阈值,以及所述第二相似度阈值,构造所述辅助损失部分,包括:通过以下公式,构造所述辅助损失部分:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL),且σHL=α>0,其中,Laux为所述辅助损失部分的损失值,yi为所述训练样本所属的类别,xi为所述训练样本采用所述识别模型提取的特征,Wyi为所述目标原型的特征,Wyi*xi为所述第一相似度,Wmax_j为与所述训练样本相似度最大的非目标原型的特征,Wmax_j*xi为所述第二相似度,σH为所述第一相似度阈值,σL为所述第二相似度阈值,α为所述相似度间隔。
可选地,所述第一相似度阈值和所述第二相似度阈值均通过以下方式至少之一确定:固定值;通过目标拟合函数得到的拟合值;通过最小化所述辅助损失部分得到的优化值。
可选地,根据所述第一相似度,所述第二相似度,所述第一相似度阈值,以及所述第二相似度阈值,构造所述辅助损失部分,包括:通过以下公式,构造所述辅助损失部分:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL)+max(0,σLH+α),其中,Laux为所述辅助损失部分的损失值,yi为所述训练样本所属的类别,xi为所述训练样本采用所述识别模型提取的特征,Wyi为所述目标原型的特征,Wyi*xi为所述第一相似度,Wmax_j为与所述训练样本相似度最大的非目标原型的特征,Wmax_j*xi为所述第二相似度,σH为所述第一相似度阈值,σL为所述第二相似度阈值,α为预定的相似度间隔。
可选地,构造所述识别模型的所述损失函数包括:将所述辅助损失部分与其它损失部分进行加权组合的方式,构造所述识别模型的所述损失函数,其中,所述其它损失部分包括以下至少之一:三元组损失函数,交叉熵损失函数。
可选地,所述识别模型包括人脸识别模型。
根据本发明实施例的另一方面,还提供了一种模型训练装置,包括:构造模块,用于构造识别模型的损失函数,其中,所述损失函数包括辅助损失部分,所述辅助损失部分用于要求第一相似度大于第一相似度阈值,所述第一相似度为训练样本与该训练样本的目标原型之间的相似度;训练模块,用于根据所述损失函数,训练所述识别模型。
在本发明实施例中,采用构造包括辅助损失部分的损失函数的方式,通过该辅助损失部分要求第一相似度大于第一相似度阈值,其中,该第一相似度为训练样本与该训练样本的目标原型之间的相似度,达到了在识别模型的训练中引入相似度阈值的目的,从而实现了在应用识别模型之前不用设置相似度阈值的操作,消除了识别模型的训练与应用之间的隔阂的技术效果,进而解决了在相关技术中,在识别模型的训练与应用之间存在隔阂的技术问题。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是根据本发明实施例的模型训练方法的流程图;
图2是根据本发明实施例提供的模型训练装置的示意性框图;
图3是根据本发明实施例提供的一种计算机终端300的示意性结构图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
根据本发明实施例,提供了一种模型训练方法的方法实施例,需要说明的是,在附图的流程图示出的步骤可以在诸如一组计算机可执行指令的计算机系统中执行,并且,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。
图1是根据本发明实施例的模型训练方法的流程图,如图1所示,该方法包括如下步骤:
步骤S102,构造识别模型的损失函数,其中,损失函数包括辅助损失部分,辅助损失部分用于要求第一相似度大于第一相似度阈值,第一相似度为训练样本与该训练样本的目标原型之间的相似度;需要说明的是,上述识别模型可以是用于识别多种对象的模型,例如,可以是用于识别人脸图像的人脸识别模型,也可以类比地应用于识别声音的声音识别模型,在以下的举例说明中,多以人脸识别模型为例进行说明。
步骤S104,根据损失函数,训练识别模型。
通过上述步骤,采用构造包括辅助损失部分的损失函数的方式,通过该辅助损失部分要求第一相似度大于第一相似度阈值,其中,该第一相似度为训练样本与该训练样本的目标原型之间的相似度,达到了在识别模型的训练中引入相似度阈值的目的,从而实现了在应用识别模型之前不用设置相似度阈值的操作,消除了识别模型的训练与应用之间的隔阂的技术效果,进而解决了在相关技术中,在识别模型的训练与应用之间存在隔阂的技术问题。
在对本实施例进行说明之前,对相关技术中对模型进行训练时一般所采用的损失函数进行说明。例如,训练多分类模型,常用到Softmax交叉熵损失函数:
其中,xi训练数据集中的第i个训练样本(例如,对于人脸识别而言,即是第i张人脸图像)的特征向量,它的类别(身份)记为yi,Wyi是全连接层中对应yi这个类别的权值向量(注:全连接层指深度神经网络的最后一层,其输出作为Soft max的输入),byi是偏置项。在具体处理时,往往将特征向量和权值向量都标准化处理(向量元素的平方和为1,此时两向量的内积等于其余弦相似度),并将偏置项byi省略(即置零)。
Soft max交叉熵损失函数的作用,就是引导深度神经网络抽取的特征xi,能被正确分类到类别yi。数学上,即让Wyi和xi的余弦相似度,尽可能大于所有Wj(j≠yi)和xi的余弦相似度,公式中的指数操作是为了加大数据间的差异,分母操作是归一化(取值在[0,1]之间)。归一化后log右边的式子,则可以理解为xi属于类别yi的预测概率,若该预测值接近1,则对应xi的损失值就趋近于0(log1=0),相反地,预测xi为类别yi的概率越小,对应的损失就越大。模型的训练过程,就是利用梯度下降法,让很多个样本的平均损失逐步降低,使得模型提取的特征,能够被正确分类。
在识别模型训练完毕后,全连接层的权值矩阵W就没有用了,后续应用时,只需要用到全连接层前面层的主干网络,用于提取待识别对象的特征向量。实际上,在人脸识别模型训练过程中,模型训练过程中的权值矩阵W,可以类比成训练数据集对应的人脸特征库,对于第i个人脸图片的特征xi,将其与人脸库W中的特征一一比对(前文计算Wyi和xi,Wj(j≠yi)和xi的余弦相似度的过程),并希望Wyi与xi最近(余弦相似度大),以得到较小的损失。这个过程中,并不涉及相似度阈值的概念,即并不要求Wyi和xi的余弦相似度要大于某个具体的阈值,仅要求Wyi和xi的相似度,要相对大于Wj(j≠yi)和xi的相似度,这就使得模型训练和应用之间存在隔阂。
为了消除隔阂,在本发明实施例中,提供了上述模型训练方法,在该模型训练方法中,用于模型训练的损失函数包括:辅助比对损失Laux,可以配合Soft max交叉熵损失函数使用,如通过加权求和组合使用:L=Lsoftmax+λLaux。具体地,Laux引入相似度阈值,明确要求Wyi和xi的余弦相似度比阈值大,Wj(j≠yi)和xi的余弦相似度比阈值小。
通过将相似度阈值(或转换为距离阈值)的选取,融入识别模型的训练过程中,不仅可以减轻或省去识别模型应用前的调参工作,而且,可以促进识别模型训练得更好,从而提取更有判别性的特征。
作为一种可选的实施例,在构造识别模型的损失函数时,可以通过以下方式,构造辅助损失部分:确定第二相似度,其中,第二相似度为与训练样本相似度最大的非目标原型与训练样本之间的相似度;根据第一相似度,第二相似度,以及第一相似度阈值,构造辅助损失部分。
可选地,根据第一相似度,第二相似度,以及第一相似度阈值,构造辅助损失部分时,可以通过以下公式,构造辅助损失部分:
Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ),其中,Laux为辅助损失部分的损失值,yi为训练样本所属的类别,xi为训练样本采用识别模型提取的特征,Wyi为目标原型的特征,Wyi*xi为第一相似度,Wmax_j为与训练样本相似度最大的非目标原型的特征,Wmax_j*xi为第二相似度,σ为第一相似度阈值。
以人脸识别模型为例,训练样本为人脸图片,对于第i个人脸图片,其对应的辅助损失部分的计算方式为:Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ),这里Wmax_j是指不考虑Wyi时,与xi相似度最大的权值向量Wj。该辅助损失部分的物理意义即是希望Wyi与xi的余弦相似度大于σ,Wj(j≠yi)和xi的余弦相似度小于σ,此时Laux的损失值为0。模型训练完毕后,可以直接以σ作为相似度阈值(或转换为距离阈值)判定是否识别成功。(注:若a>0,max(0,a)=a,否则max(0,a)=0)。
需要说明的是,上述公式及后续描述中,均以一对人脸对的损失计算为例进行说明,实际训练时,每次可以计算一个批次如128对人脸图像对的损失,并将其平均值作为最终损失。例如,为使得识别模型的训练更为准确,可以对批量样本对计算损失值,之后,对批量样本的损失值求平均。在针对批量样本对,根据损失函数确定的损失值求平均时,该批量样本对的数量可以结合训练的质量和效率的要求而灵活确定,比如,该批量样本对可以是128对等。
作为一种可选的实施例,在根据第一相似度,第二相似度,以及第一相似度阈值,构造辅助损失部分时,可以采用以下方式构造辅助损失部分:确定第二相似度阈值,其中,第一相似度阈值比第二相似度阈值至少大相似度间隔;根据第一相似度,第二相似度,第一相似度阈值,以及第二相似度阈值,构造辅助损失部分。
可选地,根据第一相似度,第二相似度,第一相似度阈值,以及第二相似度阈值,构造辅助损失部分时,可以通过以下公式,构造辅助损失部分:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL),且σHL=α>0,其中,Laux为辅助损失部分的损失值,yi为训练样本所属的类别,xi为训练样本采用识别模型提取的特征,Wyi为目标原型的特征,Wyi*xi为第一相似度,Wmax_j为与训练样本相似度最大的非目标原型的特征,Wmax_j*xi为第二相似度,σH为第一相似度阈值,σL为第二相似度阈值,α为相似度间隔。
还是以人脸识别模型为例,训练样本为人脸图片,对于第i个人脸图片,其对应的辅助损失部分的计算方式为:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL)。在该公式中,引入了两个阈值参数σH和σL,且σHL=α>0,在本可选实施例中,相当于上述Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ)的辅助损失部分,引入了相似度间隔,引导Wyi与xi的余弦相似度,要比Wmax_j和xi的余弦相似度至少大α。
需要说明的是,上述第一相似度阈值和第二相似度阈值可以分别是一个固定值;该固定值是人为设定的,固定不变的。上述第一相似度阈值和第二相似度阈值可以是变化的,例如,可以是先给定一个初始值,之后,自适应性地调整。例如,可以通过目标拟合函数得到的拟合值。比如,以人脸识别模型为例,对于第i个人脸图片,辅助损失部分为:Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ)。每张人脸图片可以构成两个人脸对:正样本对(xi,Wyi)和负样本对(xi,Wmax_j)。对于一批人脸图片,比如4096张,可以构造8192个人脸对。计算得到这些人脸对的接受者操作特性曲线(Receiver Operating Characteristic curve,简称为ROC)曲线,并拟合出目标FPR(如FPR=1e-6)对应的相似度阈值σ。
在该方式中,为了防止阈值更新时震荡,可以采用平滑策略,如σi=β*σi-1+(1-β)*σ,这里β取值在[0,1]之间,β越大代表更新越慢,σi-1和σi分别是第(i-1)次和第i次的更新值,σ是第i次计算的当前值。
为了防止阈值更新时震荡,可以部分复用用于计算ROC的人脸对。举例,可以分别利用训练集中人脸图片序号在1~4096,4097~8192,...区间对应的8192个人脸对,也可以分别利用序号在1~4096,2049~6144,...对应的8192个人脸对,后者因为有重叠的人脸对,可以减缓阈值更新震荡。
另外,上述第一相似度阈值和第二相似度阈值还可以通过最小化辅助损失部分得到的优化值。例如,可选地,根据第一相似度,第二相似度,第一相似度阈值,以及第二相似度阈值,构造辅助损失部分时,可以通过以下公式,构造辅助损失部分:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL)+max(0,σLH+α),其中,Laux为辅助损失部分的损失值,yi为训练样本所属的类别,xi为训练样本采用识别模型提取的特征,Wyi为目标原型的特征,Wyi*xi为第一相似度,Wmax_j为与训练样本相似度最大的非目标原型的特征,Wmax_j*xi为第二相似度,σH为第一相似度阈值,σL为第二相似度阈值,α为预定的相似度间隔。
即在该方式中,引入相似度阈值σH,σL,相似度间隔α,其中,相似度阈值视为模型参数,基于梯度下降自动学习其大小,间隔α预先指定。
作为一种可选的实施例,构造识别模型的损失函数时,可以将辅助损失部分与其它损失部分进行加权组合的方式,构造识别模型的损失函数,其中,其它损失部分包括以下至少之一:三元组损失函数,交叉熵损失函数。
在以上的举例中,均是以多分类的模型训练为例进行说明的。实际上,在度量学习的模型训练中也可以应用上述辅助损失部分,来实现消除模型训练与模型应用之间的隔阂的问题。
在度量学习的人脸识别模型训练过程中,三元组损失函数,同样关注的是人脸特征相似度之间的相对大小,没有考虑绝对差异,从而存在模型训练和应用之间的隔阂。三元组损失函数形如Ltri=max(0,||xa-xp||2-||xa-xn||21),其中,(xa,xp)和(xa,xn)分别对应正负样本对,α1是距离间隔,由于这里采用的欧式距离,所以由引入相似度阈值,改为引入距离阈值。对应于上述,根据第一相似度,第二相似度,第一相似度阈值,以及第二相似度阈值,构造辅助损失部分的处理,可引入两个距离阈值,仍旧记为σL和σH,但此时是正样本对要小于σL,负样本对要大于σH,辅助损失部分的计算方式可以为:
Laux=max(0,||xa-xp||2L)+max(0,σH-||xa-xn||2)+max(0,σLH+α)。
需要说明的是,在人脸识别领域中,常用多分类的方式,以及三元组损失函数来训练模型。基于Soft max交叉熵损失函数的多分类训练方式,也可以和度量学习相结合,引出了多种Soft max变体损失函数。在本申请所提出的辅助损失部分,也能天然适配到这些变化的损失函数中使用。
综上,通过上述实施例及可选实施例,本申请可以实现以下效果:
(1)有利于模型学习更具判别性的特征,因为Soft max交叉熵损失函数等损失函数,主要关注距离或相似度的相对大小,忽略了应用过程中需要依赖绝对大小的需求,本申请中提出的辅助损失部分,配合原有损失,引导模型同时学习相对大小和绝对大小。
(2)由于引入了相似度阈值或距离阈值,基于本申请中提出的辅助损失部分训练的模型,可以在模型训练完毕后的应用阶段,省去或者减轻阈值的选取工作量。
(3)相较于度量学习中关注了绝对距离差异的比对损失(contrastive loss),本申请提出的辅助损失部分,在关注绝对差异的同时,还进一步提出了自适应和自学习的阈值策略,减少了训练过程中的参数调节工作。
上文中结合图1,详细描述了根据本申请实施例的模型训练方法,下面将结合图2,描述根据本申请实施例的模型训练装置和计算机终端。
图2是根据本发明实施例提供的模型训练装置的示意性框图,如图2所示,该模型训练装置200包括:构造模块202和训练模块204,下面对该模型训练装置200进行说明。
构造模块202,用于构造识别模型的损失函数,其中,损失函数包括辅助损失部分,辅助损失部分用于要求第一相似度大于第一相似度阈值,第一相似度为训练样本与该训练样本的目标原型之间的相似度;训练模块204,连接至上述构造模块202,用于根据损失函数,训练识别模型。
可选地,作为一个实施例,上述构造模块202,还用于通过以下方式,构造辅助损失部分:确定第二相似度,其中,第二相似度为与训练样本相似度最大的非目标原型与训练样本之间的相似度;根据第一相似度,第二相似度,以及第一相似度阈值,构造辅助损失部分。
可选地,作为一个实施例,上述构造模块202,还用于通过以下公式,构造辅助损失部分:Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ),其中,Laux为辅助损失部分的损失值,yi为训练样本所属的类别,xi为训练样本采用识别模型提取的特征,Wyi为目标原型的特征,Wyi*xi为第一相似度,Wmax_j为与训练样本相似度最大的非目标原型的特征,Wmax_j*xi为第二相似度,σ为第一相似度阈值。
可选地,作为一个实施例,上述构造模块202,还用于确定第二相似度阈值,其中,第一相似度阈值比第二相似度阈值至少大相似度间隔;根据第一相似度,第二相似度,第一相似度阈值,以及第二相似度阈值,构造辅助损失部分。
可选地,作为一个实施例,上述构造模块202,还用于通过以下公式,构造辅助损失部分:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL),且σHL=α>0,其中,Laux为辅助损失部分的损失值,yi为训练样本所属的类别,xi为训练样本采用识别模型提取的特征,Wyi为目标原型的特征,Wyi*xi为第一相似度,Wmax_j为与训练样本相似度最大的非目标原型的特征,Wmax_j*xi为第二相似度,σH为第一相似度阈值,σL为第二相似度阈值,α为相似度间隔。
可选地,作为一个实施例,该装置还包括:确定模块,用于通过以下方式至少之一,确定第一相似度阈值和第二相似度阈值:固定值;通过目标拟合函数得到的拟合值;通过最小化辅助损失部分得到的优化值。
可选地,作为一个实施例,上述构造模块202,还用于通过以下公式,构造辅助损失部分:Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL)+max(0,σLH+α),其中,Laux为辅助损失部分的损失值,yi为训练样本所属的类别,xi为训练样本采用识别模型提取的特征,Wyi为目标原型的特征,Wyi*xi为第一相似度,Wmax_j为与训练样本相似度最大的非目标原型的特征,Wmax_j*xi为第二相似度,σH为第一相似度阈值,σL为第二相似度阈值,α为预定的相似度间隔。
可选地,作为一个实施例,上述构造模块202,还用于将辅助损失部分与其它损失部分进行加权组合的方式,构造识别模型的损失函数,其中,其它损失部分包括以下至少之一:三元组损失函数,交叉熵损失函数。
可选地,作为一个实施例,识别模型包括人脸识别模型。
应理解,根据本申请实施例的装置中的各个单元的上述和其它操作和/或功能分别为了实现上述各个方法中的相应流程,为了简洁,在此不再赘述。
图3是根据本发明实施例提供的一种计算机终端300的示意性结构图。图3所示的计算机终端300包括处理器310,处理器310可以从存储器中调用并运行计算机程序,以实现本申请实施例中的方法。
可选地,如图3所示,计算机终端300还可以包括存储器320。其中,处理器310可以从存储器320中调用并运行计算机程序,以实现本申请实施例中的方法。
其中,存储器320可以是独立于处理器310的一个单独的器件,也可以集成在处理器310中。
可选地,如图3所示,计算机终端300还可以包括收发器330,处理器310可以控制该收发器330与其他设备进行通信,具体地,可以向其他设备发送信息或数据,或接收其他设备发送的信息或数据。
其中,收发器330可以包括发射机和接收机。收发器330还可以进一步包括天线,天线的数量可以为一个或多个。
可选地,该计算机终端300可以实现本申请实施例的各个方法中实现的相应流程,为了简洁,在此不再赘述。
应理解,本申请实施例的处理器可能是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法实施例的各步骤可以通过处理器中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器可以是通用处理器、数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。结合本申请实施例所公开的方法的步骤可以直接体现为硬件译码处理器执行完成,或者用译码处理器中的硬件及软件模块组合执行完成。软件模块可以位于随机存储器,闪存、只读存储器,可编程只读存储器或者电可擦写可编程存储器、寄存器等本领域成熟的存储介质中。该存储介质位于存储器,处理器读取存储器中的信息,结合其硬件完成上述方法的步骤。
可以理解,本申请实施例中的存储器可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(Read-Only Memory,ROM)、可编程只读存储器(Programmable ROM,PROM)、可擦除可编程只读存储器(Erasable PROM,EPROM)、电可擦除可编程只读存储器(Electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(Random Access Memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(Synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double Data RateSDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(Synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(DirectRambus RAM,DR RAM)。应注意,本文描述的系统和方法的存储器旨在包括但不限于这些和任意其它适合类型的存储器。
应理解,上述存储器为示例性但不是限制性说明,例如,本申请实施例中的存储器还可以是静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(Synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double Data Rate SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(Synch Link DRAM,SLDRAM)以及直接内存总线随机存取存储器(Direct Rambus RAM,DR RAM)等等。也就是说,本申请实施例中的存储器旨在包括但不限于这些和任意其它适合类型的存储器。
本申请实施例还提供了一种计算机可读存储介质,用于存储计算机程序。
可选的,该计算机可读存储介质可应用于本申请实施例中的计算机终端,并且该计算机程序使得计算机终端执行本申请实施例的各个方法中的相应流程,为了简洁,在此不再赘述。
本申请实施例还提供了一种计算机程序产品,包括计算机程序指令。
可选的,该计算机程序产品可应用于本申请实施例中的计算机终端,并且该计算机程序指令使得计算机终端执行本申请实施例的各个方法中的相应流程,为了简洁,在此不再赘述。
本申请实施例还提供了一种计算机程序。
可选的,该计算机程序可应用于本申请实施例中的计算机终端,当该计算机程序在计算机上运行时,使得计算机终端执行本申请实施例的各个方法中的相应流程,为了简洁,在此不再赘述。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
在本发明的上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见其他实施例的相关描述。
在本申请所提供的几个实施例中,应该理解到,所揭露的技术内容,可通过其它的方式实现。其中,以上所描述的装置实施例仅仅是示意性的,例如所述单元的划分,可以为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,单元或模块的间接耦合或通信连接,可以是电性或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可为个人计算机、服务器或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、移动硬盘、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。

Claims (7)

1.一种模型训练方法,其特征在于,包括:
构造识别模型的损失函数,其中,所述损失函数包括辅助损失部分,所述辅助损失部分用于要求第一相似度大于第一相似度阈值,所述第一相似度为训练样本与该训练样本的目标原型之间的相似度,所述识别模型包括人脸识别模型,所述训练样本为人脸图片;
根据所述损失函数,训练所述识别模型;
所述方法还包括,通过以下方式,构造所述辅助损失部分:
确定第二相似度,其中,所述第二相似度为与所述训练样本相似度最大的非目标原型与所述训练样本之间的相似度;
根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分;
其中,根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分,包括:
通过以下公式,构造所述辅助损失部分:
Laux=max(0,σ-Wyi*xi)+max(0,Wmax_j*xi-σ),
其中,Laux为所述辅助损失部分的损失值,yi为所述训练样本所属的类别,xi为所述训练样本采用所述识别模型提取的特征,Wyi为所述目标原型的特征,Wyi*xi为所述第一相似度,Wmax_j为与所述训练样本相似度最大的非目标原型的特征,Wmax_j*xi为所述第二相似度,σ为所述第一相似度阈值。
2.根据权利要求1所述的方法,其特征在于,构造所述识别模型的所述损失函数包括:
将所述辅助损失部分与其它损失部分进行加权组合的方式,构造所述识别模型的所述损失函数,其中,所述其它损失部分包括以下至少之一:三元组损失函数,交叉熵损失函数。
3.一种模型训练方法,其特征在于,包括:
构造识别模型的损失函数,其中,所述损失函数包括辅助损失部分,所述辅助损失部分用于要求第一相似度大于第一相似度阈值,所述第一相似度为训练样本与该训练样本的目标原型之间的相似度,所述识别模型包括人脸识别模型,所述训练样本为人脸图片;
根据所述损失函数,训练所述识别模型;
所述方法还包括,通过以下方式,构造所述辅助损失部分:
确定第二相似度,其中,所述第二相似度为与所述训练样本相似度最大的非目标原型与所述训练样本之间的相似度;
根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分;
根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分,包括:
确定第二相似度阈值,其中,所述第一相似度阈值比所述第二相似度阈值至少大相似度间隔;
根据所述第一相似度,所述第二相似度,所述第一相似度阈值,以及所述第二相似度阈值,构造所述辅助损失部分;
根据所述第一相似度,所述第二相似度,所述第一相似度阈值,以及所述第二相似度阈值,构造所述辅助损失部分,包括:
通过以下公式,构造所述辅助损失部分:
Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL),且σHL=α>0,
其中,Laux为所述辅助损失部分的损失值,yi为所述训练样本所属的类别,xi为所述训练样本采用所述识别模型提取的特征,Wyi为所述目标原型的特征,Wyi*xi为所述第一相似度,Wmax_j为与所述训练样本相似度最大的非目标原型的特征,Wmax_j*xi为所述第二相似度,σH为所述第一相似度阈值,σL为所述第二相似度阈值,α为所述相似度间隔。
4.根据权利要求3所述的方法,其特征在于,所述第一相似度阈值和所述第二相似度阈值均通过以下方式至少之一确定:
固定值;
通过目标拟合函数得到的拟合值;
通过最小化所述辅助损失部分得到的优化值。
5.根据权利要求3或4所述的方法,其特征在于,构造所述识别模型的所述损失函数包括:
将所述辅助损失部分与其它损失部分进行加权组合的方式,构造所述识别模型的所述损失函数,其中,所述其它损失部分包括以下至少之一:三元组损失函数,交叉熵损失函数。
6.一种模型训练方法,其特征在于,包括:
构造识别模型的损失函数,其中,所述损失函数包括辅助损失部分,所述辅助损失部分用于要求第一相似度大于第一相似度阈值,所述第一相似度为训练样本与该训练样本的目标原型之间的相似度,所述识别模型包括人脸识别模型,所述训练样本为人脸图片;
根据所述损失函数,训练所述识别模型;
所述方法还包括,通过以下方式,构造所述辅助损失部分:
确定第二相似度,其中,所述第二相似度为与所述训练样本相似度最大的非目标原型与所述训练样本之间的相似度;
根据所述第一相似度,第二相似度,以及所述第一相似度阈值,构造所述辅助损失部分;
根据所述第一相似度,所述第二相似度,所述第一相似度阈值,以及所述第二相似度阈值,构造所述辅助损失部分,包括:
通过以下公式,构造所述辅助损失部分:
Laux=max(0,σH-Wyi*xi)+max(0,Wmax_j*xi-σL)+max(0,σLH+α),
其中,Laux为所述辅助损失部分的损失值,yi为所述训练样本所属的类别,xi为所述训练样本采用所述识别模型提取的特征,Wyi为所述目标原型的特征,Wyi*xi为所述第一相似度,Wmax_j为与所述训练样本相似度最大的非目标原型的特征,Wmax_j*xi为所述第二相似度,σH为所述第一相似度阈值,σL为所述第二相似度阈值,α为预定的相似度间隔。
7.根据权利要求6所述的方法,其特征在于,构造所述识别模型的所述损失函数包括:
将所述辅助损失部分与其它损失部分进行加权组合的方式,构造所述识别模型的所述损失函数,其中,所述其它损失部分包括以下至少之一:三元组损失函数,交叉熵损失函数。
CN202010520344.0A 2020-06-09 2020-06-09 模型训练方法及装置 Active CN111680631B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010520344.0A CN111680631B (zh) 2020-06-09 2020-06-09 模型训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010520344.0A CN111680631B (zh) 2020-06-09 2020-06-09 模型训练方法及装置

Publications (2)

Publication Number Publication Date
CN111680631A CN111680631A (zh) 2020-09-18
CN111680631B true CN111680631B (zh) 2023-12-22

Family

ID=72454229

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010520344.0A Active CN111680631B (zh) 2020-06-09 2020-06-09 模型训练方法及装置

Country Status (1)

Country Link
CN (1) CN111680631B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113627361B (zh) * 2021-08-13 2023-08-08 北京百度网讯科技有限公司 人脸识别模型的训练方法、装置及计算机程序产品

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108197670A (zh) * 2018-01-31 2018-06-22 国信优易数据有限公司 伪标签生成模型训练方法、装置及伪标签生成方法及装置
CN109165566A (zh) * 2018-08-01 2019-01-08 中国计量大学 一种基于新型损失函数的人脸识别卷积神经网络训练方法
CN109816092A (zh) * 2018-12-13 2019-05-28 北京三快在线科技有限公司 深度神经网络训练方法、装置、电子设备及存储介质
CN110197102A (zh) * 2018-02-27 2019-09-03 腾讯科技(深圳)有限公司 人脸识别方法及装置
WO2019227672A1 (zh) * 2018-05-28 2019-12-05 平安科技(深圳)有限公司 说话人分离模型训练方法、两说话人分离方法及相关设备
EP3582150A1 (en) * 2018-06-13 2019-12-18 Fujitsu Limited Method of knowledge transferring, information processing apparatus and storage medium

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108197670A (zh) * 2018-01-31 2018-06-22 国信优易数据有限公司 伪标签生成模型训练方法、装置及伪标签生成方法及装置
CN110197102A (zh) * 2018-02-27 2019-09-03 腾讯科技(深圳)有限公司 人脸识别方法及装置
WO2019227672A1 (zh) * 2018-05-28 2019-12-05 平安科技(深圳)有限公司 说话人分离模型训练方法、两说话人分离方法及相关设备
EP3582150A1 (en) * 2018-06-13 2019-12-18 Fujitsu Limited Method of knowledge transferring, information processing apparatus and storage medium
CN109165566A (zh) * 2018-08-01 2019-01-08 中国计量大学 一种基于新型损失函数的人脸识别卷积神经网络训练方法
CN109816092A (zh) * 2018-12-13 2019-05-28 北京三快在线科技有限公司 深度神经网络训练方法、装置、电子设备及存储介质

Also Published As

Publication number Publication date
CN111680631A (zh) 2020-09-18

Similar Documents

Publication Publication Date Title
CN111767900B (zh) 人脸活体检测方法、装置、计算机设备及存储介质
CN111523621A (zh) 图像识别方法、装置、计算机设备和存储介质
CN110188829B (zh) 神经网络的训练方法、目标识别的方法及相关产品
CN111133453B (zh) 人工神经网络
US20190220653A1 (en) Compact models for object recognition
WO2019233226A1 (zh) 人脸识别方法、分类模型训练方法、装置、存储介质和计算机设备
WO2020258981A1 (zh) 基于眼底图像的身份信息处理方法及设备
US10255487B2 (en) Emotion estimation apparatus using facial images of target individual, emotion estimation method, and non-transitory computer readable medium
KR101412727B1 (ko) 얼굴 인식 장치 및 방법
CN111898735A (zh) 蒸馏学习方法、装置、计算机设备和存储介质
CN110321964B (zh) 图像识别模型更新方法及相关装置
WO2019196626A1 (zh) 媒体处理方法及相关装置
CN111680631B (zh) 模型训练方法及装置
CN116110100A (zh) 一种人脸识别方法、装置、计算机设备及存储介质
CN111310516A (zh) 一种行为识别方法和装置
JP6600288B2 (ja) 統合装置及びプログラム
CN115795355B (zh) 一种分类模型训练方法、装置及设备
CN113326832B (zh) 模型训练、图像处理方法、电子设备及存储介质
CN115578765A (zh) 目标识别方法、装置、系统及计算机可读存储介质
CN111680636A (zh) 模型训练方法及装置
CN116071472A (zh) 图像生成方法及装置、计算机可读存储介质、终端
CN113327212B (zh) 人脸驱动、模型的训练方法、装置、电子设备及存储介质
CN113283388B (zh) 活体人脸检测模型的训练方法、装置、设备及存储介质
CN114677535A (zh) 域适应图像分类网络的训练方法、图像分类方法及装置
CN108764106B (zh) 基于级联结构的多尺度彩色图像人脸比对方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant