CN111523422B - 一种关键点检测模型训练方法、关键点检测方法和装置 - Google Patents

一种关键点检测模型训练方法、关键点检测方法和装置 Download PDF

Info

Publication number
CN111523422B
CN111523422B CN202010294788.7A CN202010294788A CN111523422B CN 111523422 B CN111523422 B CN 111523422B CN 202010294788 A CN202010294788 A CN 202010294788A CN 111523422 B CN111523422 B CN 111523422B
Authority
CN
China
Prior art keywords
model
thermodynamic diagram
training
loss function
image sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010294788.7A
Other languages
English (en)
Other versions
CN111523422A (zh
Inventor
赵佳
李骊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing HJIMI Technology Co Ltd
Original Assignee
Beijing HJIMI Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing HJIMI Technology Co Ltd filed Critical Beijing HJIMI Technology Co Ltd
Priority to CN202010294788.7A priority Critical patent/CN111523422B/zh
Publication of CN111523422A publication Critical patent/CN111523422A/zh
Application granted granted Critical
Publication of CN111523422B publication Critical patent/CN111523422B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • G06V40/16Human faces, e.g. facial parts, sketches or expressions
    • G06V40/161Detection; Localisation; Normalisation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • G06V40/16Human faces, e.g. facial parts, sketches or expressions
    • G06V40/172Classification, e.g. identification
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • Oral & Maxillofacial Surgery (AREA)
  • Human Computer Interaction (AREA)
  • Multimedia (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例公开一种关键点检测模型训练方法、关键点检测方法和装置,在关键点检测模型训练中同时利用有标注和无标注图像样本。在训练时,通过生成模型根据获取的未标注图像样本生成第一热力图,根据获取的标注图像样本中标注的关键点坐标确定第二热力图。通过判别模型计算第一热力图和未标注图像样本的第一匹配度,以及计算第二热力图和标注图像样本的第二匹配度,根据第一匹配度和第二匹配度构建对抗损失函数。从而根据该对抗损失函数对生成模型和判别模型进行训练。由于在训练时采用了大量无标注图像样本,仅需要少量的标注图像样本,降低了人工标注成本,提高了模型训练的效率。同时,减小了人工标注带来的标注偏差,提升了模型训练的效果。

Description

一种关键点检测模型训练方法、关键点检测方法和装置
技术领域
本申请涉及机器学习领域,特别是涉及一种关键点检测模型训练方法、关键点检测方法和装置。
背景技术
关键点检测是指在图像或视频中确定出感兴趣的关键位置的坐标。例如,在人脸关键点检测中,从包扩人脸的图像中确定出内外眼角、鼻尖、嘴角等关键位置的坐标。关键点检测是计算机视觉应用的重要组成部分,对于人脸识别、表情识别、姿态识别等领域有着重要的作用。
目前的关键点检测方法主要基于深度神经网络的关键点检测模型实现,关键点检测模型主要是基于端到端的全监督方式进行训练,即训练数据全部为标注数据。
然而,这种训练方式需要大量的关键点标注数据,但是人工标注成本高,耗时长,而且不同的标注者对相同的关键点给出的坐标往往存在偏差,进而可能导致训练得到的模型难以准确的预测关键点位置。
发明内容
为了解决上述技术问题,本申请提供了一种关键点检测模型训练方法、关键点检测方法和装置,仅需要少量的标注图像样本,大大降低了人工标注成本,提高了模型训练的效率。同时,尽量减小了人工标注带来的标注偏差,提升了模型训练的效果。
本申请实施例公开了如下技术方案:
第一方面,本申请实施例提供一种关键点检测模型训练方法,所述方法包括:
通过生成模型根据获取的未标注图像样本生成第一热力图;
根据获取的标注图像样本中标注的关键点坐标确定第二热力图;
通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;
根据所述第一匹配度和所述第二匹配度构建对抗损失函数;
根据所述对抗损失函数对所述生成模型和判别模型进行训练。
可选的,根据所述对抗损失函数对所述生成模型和判别模型进行训练,包括:
将所述对抗损失函数作为所述判别模型的损失函数对所述判别模型进行训练;
根据所述对抗损失函数和散度损失函数构建所述生成模型的损失函数,对所述生成模型进行训练;所述散度损失函数用于表示所述标注图像样本的所述第二热力图与第三热力图之间的差距;所述第三热力图是所述生成模型根据所述标注图像样本生成的。
可选的,所述生成模型的损失函数为LG=LKL-λLadv;其中,LG为所述生成模型的损失函数,LKL为所述散度损失函数,Ladv为所述对抗损失函数,λ为损失权重乘积。
可选的,所述根据获取的标注图像样本中标注的关键点坐标确定第二热力图,包括:
根据所述关键点坐标计算均值和均方差;
根据所述均值和均方差计算所述第二热力图。
第二方面,本申请实施例提供一种关键点检测方法,所述方法包括:
获取待检测图像;
通过生成模型生成热力图;所述生成模型是根据标注图像样本和未标注图像样本,与判别模型进行对抗训练得到的;所述对抗训练的方式为通过生成模型根据所述未标注图像样本生成第一热力图;根据所述标注图像样本中标注的关键点坐标确定第二热力图;通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;根据所述第一匹配度和所述第二匹配度构建对抗损失函数;根据所述对抗损失函数对所述生成模型和判别模型进行训练;
根据所述热力图确定关键点坐标。
第三方面,本申请实施例提供一种关键点检测模型训练装置,所述装置包括:
生成单元,用于通过生成模型根据获取的未标注图像样本生成第一热力图;
确定单元,用于根据获取的标注图像样本中标注的关键点坐标确定第二热力图;
计算单元,用于通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;
构建单元,用于根据所述第一匹配度和所述第二匹配度构建对抗损失函数;
训练单元,用于根据所述对抗损失函数对所述生成模型和判别模型进行训练。
可选的,所述训练单元,用于:
将所述对抗损失函数作为所述判别模型的损失函数对所述判别模型进行训练;
根据所述对抗损失函数和散度损失函数构建所述生成模型的损失函数,对所述生成模型进行训练;所述散度损失函数用于表示所述标注图像样本的所述第二热力图与第三热力图之间的差距;所述第三热力图是所述生成模型根据所述标注图像样本生成的。
可选的,所述生成模型的损失函数为LG=LKL-λLadv;其中,LG为所述生成模型的损失函数,LKL为所述散度损失函数,Ladv为所述对抗损失函数,λ为损失权重乘积。
可选的,所述确定单元,用于:
根据所述关键点坐标计算均值和均方差;
根据所述均值和均方差计算所述第二热力图。
第四方面,本申请实施例提供一种关键点检测装置,所述装置包括:
获取单元,用于获取待检测图像;
生成单元,用于通过生成模型生成热力图;所述生成模型是根据标注图像样本和未标注图像样本,与判别模型进行对抗训练得到的;所述对抗训练的方式为通过生成模型根据所述未标注图像样本生成第一热力图;根据所述标注图像样本中标注的关键点坐标确定第二热力图;通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;根据所述第一匹配度和所述第二匹配度构建对抗损失函数;根据所述对抗损失函数对所述生成模型和判别模型进行训练;
确定单元,用于根据所述热力图确定关键点坐标。
由上述技术方案可以看出,本申请实施例引入对抗训练机制,使得关键点检测模型在训练中可以同时利用有标注和无标注图像样本,在训练时,通过生成模型根据获取的未标注图像样本生成第一热力图,以及根据获取的标注图像样本中标注的关键点坐标确定第二热力图。然后,通过判别模型计算第一热力图和未标注图像的第一匹配度,以及计算第二热力图和标注图像样本的第二匹配度,根据第一匹配度和第二匹配度构建对抗损失函数。从而根据该对抗损失函数对生成模型和判别模型进行训练。由于在训练时采用了无标注图像样本进行半监督训练,从而仅需要少量的标注图像样本,大大降低了人工标注成本,提高了模型训练的效率。同时,尽量减小了人工标注带来的标注偏差,提升了模型训练的效果。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的一种关键点检测模型训练方法的流程图;
图2为本申请实施例提供的一种关键点检测模型训练方法的流程图;
图3为本申请实施例提供的一种关键点检测方法的流程图;
图4为本申请实施例提供的一种关键点检测模型训练装置的结构图;
图5为本申请实施例提供的一种关键点检测装置的结构图。
具体实施方式
为了使本技术领域的人员更好地理解本申请方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
相关技术中,模型通常采用端到端的全监督方式训练。深度神经网络模型参数数量通常很大(百万级以上),因此全监督训练需要大量的关键点标注数据,即人脸图像和对应的人工标注脸部关键点坐标。但是人工标注成本高,耗时长,而且不同的标注者对相同的关键点给出的坐标往往存在偏差。这也导致目前大规模人脸关键点标注数据集很少,很难满足训练需求。另一方面,未标注的人脸图像大量存在且可免费获得,如果能够合理利用这些未标注数据,关键点检测模型效果将会有较大的提升。
为了解决上述技术问题,本申请实施例提供一种关键点检测模型训练方法、关键点检测方法和装置,引入对抗训练机制,使得关键点检测模型在训练中可以同时利用有标注和无标注图像样本,从而仅需要少量的标注图像样本,大大降低了人工标注成本,提高了模型训练的效率。同时,尽量减小了人工标注带来的标注偏差,提升了模型训练的效果。
本申请实施例提供的方法主要应用于人脸识别、表情识别、姿态识别等方面,为了便于介绍,后续将主要以人脸识别为例进行介绍。
接下来,将结合附图对本申请实施例提供的关键点检测模型训练方法进行详细介绍。
参见图1,图1示出了一种关键点检测模型训练方法的流程图,方法包括:
S101、通过生成模型根据获取的未标注图像样本生成第一热力图。
本申请实施例需要从训练数据集中选取训练数据,训练数据中包括标注图像样本。每个标注图像样本(例如人脸图像)记为其中i为样本索引,Il∈Rh×w×3为标注人脸图像(图像宽h,高w,有3个颜色通道),sl∈Rk×2为标注人脸图像中人工标注的关键点的坐标向量(k为关键点个数)。
每个未标注样本记为其中i为样本索引,Iu∈Rh×w×3为未标注人脸图像(图像宽h,高w,有3个颜色通道)。
在每次训练迭代时,通常需要从训练集中随机抽取小批量样本IB,样本数量可根据实际情况自主选择。小批量标注图像样本记为小批量未标注图像样本记为/>
需要说明的是,本申请实施例引入对抗训练机制对关键点检测模型进行训练,此时,关键点检测模型可以是生成式对抗网络(Generative Adversarial Networks,GAN)模型,包括生成模型Gθ和判别模型Dφ,生成模型的输入是训练数据,在人脸识别中则输入的是人脸图像I∈Rh×w×3,包括标注人脸图像和未标注人脸图像。生成模型的结构可采用类似编码--解码(即沙漏状)的网络结构,网络参数为θ。
对于未标注人脸图像,生成模型可以根据未标注图像样本生成其对应的热力图,例如第一热力图Hu∈Rh×w×k,即Hu=Gθ(I)。该第一热力图是生成模型根据未标注图像样本预测得到的,I为表示未标注图像样本的Iu
其中,热力图一般包含k层(k为关键点个数),每层热力图描述一个对应关键点的概率分布。热力图的长和宽与输入人脸图像相同。
S102、根据获取的标注图像样本中标注的关键点坐标确定第二热力图。
由于标注图像样本中标注出了关键点坐标,故S102中确定第二热力图时,无需生成模型预测其对应的热力图,而是可以直接根据标注图像样本中标注的关键点坐标推到出其对应的热力图,例如第二热力图。
本申请实施例提供一种根据关键点坐标推导热力图的方式,该方式可以是根据标注的关键点坐标计算均值和均方差,然后根据得到的均值和协方差计算得到第二热力图。其中,第二热力图例如可以是以该均值和协方差为单位矩阵的二维高斯概率分布,即Hl=Gaussian(μ,Σ)。
其中,Hl为第二热力图,μ为均值,μ=[xl,yl],Σ为协方差,
假设这个关键点通过生成模型预测得到的第一热力图为Hu,可以使用2维(2Dimensional,2D)softargmax操作(一种计算极大值的方式)计算该关键点坐标的期望值,即其中γ为温度因子。然后计算热力图Hu的协方差其中x=(x,y)表示Hu中某点的坐标。
通过均值为协方差为/>的二维高斯概率分布来近似热力图Hu,即
S103、通过判别模型计算所述第一热力图和所述未标注图像的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度。
判别模型的输入是人脸图像与热力图组成的图像对,输出是一个分数m,m=Dφ([I,H]),代表热力与其对应的人脸图像的匹配度。判别模型的具体结构可自主设计,网络参数为φ。
当未标注图像样本和其对应的第一热力图输入判别模型时,输入的图像对可以表示为[Iu,Hu],通过判别模型计算第一热力图和未标注图像的第一匹配度,即Dφ([Iu,Gθ(Iu)])。当标注图像样本和其对应的第二热力图输入判别模型时,输入的图像对可以表示为[Il,Hl],通过判别模型计算第二热力图和标注图像样本的第二匹配度,即Dφ([Il,Hl])。
需要说明的是,分数高表示人脸图像与热力图的匹配度低,也就是热力图并不是人脸图像的标注热力图。因此,对于Dφ([Iu,Gθ(Iu)])来说,Dφ([Iu,Gθ(Iu)])越高越好,说明判别模型的判别能力较高,可以判别出第一热力图是预测得到的,并非实际标注得到的。相反的,分数低表示人脸图像与热力图的匹配度高,对于Dφ([Il,Hl])来说,Dφ([Il,Hl])越低越好,说明判别模型的判别能力较高,可以判别出第二热力图是根据实际标注得到的。
S104、根据所述第一匹配度和所述第二匹配度构建对抗损失函数。
对抗损失函数可以表示为:其中,Ladv为对抗损失函数,/>为第二匹配度,/>为第一匹配度,/>为小批量样本中的标注图像样本,/>为标注图像样本对应的第二热力图,/>为小批量样本中的未标注图像样本,/>为未标注图像样本对应的第一热力图。
S105、根据所述对抗损失函数对所述生成模型和判别模型进行训练。
在进行训练时,采用的是对抗训练机制,即在优化生成模型时,固定判别模型的参数,从而根据生成模型的损失函数更新生成模型的参数Gθ;在优化判别模型时,固定生成模型的参数,从而根据网络模型的损失函数更新判别模型的参数Dφ,直到生成模型和判别模型符合条件则结束训练,即二者各自的损失函数所代表的损失最小。
在训练过程中,采用循环的方式,不断的调整生成模型和判别模型的参数,使得生成模型生成的第一热力图更加接近标注热力图(即根据标注的关键点坐标推导出的热力图,而非生成模型预测出的热力图),使得判别模型难以区分预测出的热力图和标注热力图。同时使得判别模型的判别能力不断提高,能够准确的区分预测出的热力图和标注热力图。
需要说明的是,在更新生成模型的参数时,生成模型的参数可以按照θ←θ-α▽θLG的形式进行更新,其中,LG表示生成模型的损失函数,▽φLG表示生成模型的损失函数LG对于参数θ的梯度,α代表学习率;在更新判别模型的参数时,判别模型的参数可以按照φ←φ-α▽φLD的形式进行更新,其中,LD表示判别模型的损失函数,▽φLD表示判别模型的损失函数LD对于参数φ的梯度,α代表学习率。
由上述技术方案可以看出,本申请实施例引入对抗训练机制,使得关键点检测模型在训练中可以同时利用有标注和无标注图像样本,在训练时,通过生成模型根据获取的未标注图像样本生成第一热力图,以及根据获取的标注图像样本中标注的关键点坐标确定第二热力图。然后,通过判别模型计算第一热力图和未标注图像的第一匹配度,以及计算第二热力图和标注图像样本的第二匹配度,根据第一匹配度和第二匹配度构建对抗损失函数。从而根据该对抗损失函数对生成模型和判别模型进行训练。由于在训练时采用了无标注图像样本进行半监督训练,从而仅需要少量的标注图像样本,大大降低了人工标注成本,提高了模型训练的效率。同时,尽量减小了人工标注带来的标注偏差,提升了模型训练的效果。另外,未标注图像样本一般是免费的,从而降低了模型训练的费用成本。
相关技术中,在进行模型训练时,模型在图像样本上的损失即损失函数通常可表示为其中l表示损失函数,/>表示模型预测的关键点坐标,si表示关键点坐标真值(即标注值),p通常取值为2(对应二范数或欧式距离)或1(对应一范数或曼哈顿距离)。这种计算损失的方式关注模型最终预测值(预测的关键点坐标)与真值(即标注值)间的差异,但是忽略了模型预测值的概率分布信息。也就是说,尽管模型的预测值与标注值接近,但模型可能对预测值的置信度不高,当输入图像稍作改变时,预测值可能就会出现较大偏差,鲁棒性较差。
例如,图像A和B的标注值是10,通过训练得到的模型对图像A进行预测,其得到的预测值可能是9,但是对图像B进行预测,得到的预测值可能是11。二者的预测值分别与对应的标注值比较接近,但是预测值却又较大偏差。
因此,为了提高模型预测的置信度,在本申请实施例中,在根据对抗损失函数训练生成模型时,在生成模型的损失函数中引入预测值的概率分布信息。其中,在损失函数中引入概率分布信息的方式可以包括多种,本申请实施例提供的方式可以是通过KL散度(又称相对熵)体现概率分布信息。
需要说明的是,S105中在对判别模型进行训练时,可以将对抗损失函数作为判别模型的损失函数对判别模型进行训练,即判别模型的损失函数LD=Ladv。然而对生成模型进行训练时,为了提高生成模型预测的置信度,可以根据对抗损失函数和散度损失函数构建生成模型的损失函数,对生成模型进行训练。其中,散度损失函数用于表示所述标注图像样本的所述第二热力图与第三热力图之间的差距;第三热力图是所述生成模型根据所述标注图像样本生成的。
散度损失函数可以表示为LKL=KL(H||Hl),其中,LKL表示散度损失函数,H表示第三热力图,Hl表示第二热力图。
根据对抗损失函数和散度损失函数构建生成模型的损失函数的方式可以是将散度损失函数与对抗损失函数作差,得到生成模型的损失函数可以表示为LG=LKL-λLadv;其中,LG为所述生成模型的损失函数,LKL为所述散度损失函数,Ladv为所述对抗损失函数,λ为损失权重乘积。
由于在对生成模型进行训练过程中,生成模型的损失函数中引入了散度损失函数,散度损失函数可以体现生成模型预测得到的标注图像样本的第二热力图(标注热力图)与第三热力图(预测热力图)之间的差距,热力图描述一个对应关键点的概率分布,即体现了概率分布信息。因此,根据该损失函数训练得到的生成模型充分考虑了预测值的概率分布信息,从而提高了模型预测的置信度。
基于上述对关键点检测模型训练方法的介绍,接下来将结合实际应用场景对本申请实施例提供的关键点检测模型的训练方法进行介绍。在该应用场景中,关键点检测模型包括生成模型Gθ和判别模型Dφ,生成模型Gθ的网络参数为θ,判别模型Dφ的网络参数为φ。参见图2,所述方法包括:
S201、初始化网络参数θ、φ。
S202、进入循环。
S203、随机抽取小批量标注图像样本和小批量未标注图像样本/>
S204、计算对抗损失函数Ladv
S205、计算散度损失函数LKL
S206、将对抗损失函数判别网络Dφ的损失函数LD
S207、计算生成网络Gθ的损失函数LG
S208、更新判别网络的网络参数Dφ
S209、更新生成网络的网络参数Gθ
S210、判断是否完成训练,若是,则结束,若否则返回S202。
在训练得到关键点检测模型后,可以利用关键点检测模型对输入的待检测图像进行检测,检测得到关键点坐标。本申请提供的关键点检测方法的流程图可以参见图3所示,所述方法包括:
S301、获取待检测图像。
S302、通过生成模型生成热力图。
该生成模型根据图1和图2对应实施例所提供的方法训练得到的,即根据标注图像样本和未标注图像样本,与判别模型进行对抗训练得到的。对抗训练的方式为通过生成模型根据未标注图像样本生成第一热力图;根据标注图像样本中标注的关键点坐标确定第二热力图;通过判别模型计算第一热力图和未标注图像样本的第一匹配度,以及计算第二热力图和标注图像样本的第二匹配度;根据第一匹配度和第二匹配度构建对抗损失函数;根据对抗损失函数对生成模型和判别模型进行训练。
S303、根据所述热力图确定关键点坐标。
在得到热力图后,可以将热力图中热力最高的点确定为关键点,从而得到关键点坐标。在热力图中确定热力最高点的方式可以是利用softargmax操作。
例如,针对一张待检测图像I,将待检测图像I输入到生成模型,生成模型预测的关键点坐标为s=softargmax(γGθ(I)),其中,s表示关键点坐标,softargmax()表示计算极大值函数,γ为温度因子,Gθ(I)为生成模型预测得到的热力图。
基于前述实施例提供的关键点检测模型训练方法,本申请实施例提供一种关键点检测模型训练装置,参见图4,所述装置包括:
生成单元401,用于通过生成模型根据获取的未标注图像样本生成第一热力图;
确定单元402,用于根据获取的标注图像样本中标注的关键点坐标确定第二热力图;
计算单元403,用于通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;
构建单元404,用于根据所述第一匹配度和所述第二匹配度构建对抗损失函数;
训练单元405,用于根据所述对抗损失函数对所述生成模型和判别模型进行训练。
可选的,所述训练单元,用于:
将所述对抗损失函数作为所述判别模型的损失函数对所述判别模型进行训练;
根据所述对抗损失函数和散度损失函数构建所述生成模型的损失函数,对所述生成模型进行训练;所述散度损失函数用于表示所述标注图像样本的所述第二热力图与第三热力图之间的差距;所述第三热力图是所述生成模型根据所述标注图像样本生成的。
可选的,所述生成模型的损失函数为LG=LKL-λLadv;其中,LG为所述生成模型的损失函数,LKL为所述散度损失函数,Ladv为所述对抗损失函数,λ为损失权重乘积。
可选的,所述确定单元,用于:
根据所述关键点坐标计算均值和均方差;
根据所述均值和均方差计算所述第二热力图。
基于前述实施例提供的关键点检测方法,本申请实施例提供一种关键点检测装置,参见图5,所述装置包括:
获取单元501,用于获取待检测图像;
生成单元502,用于通过生成模型生成热力图;所述生成模型是根据标注图像样本和未标注图像样本,与判别模型进行对抗训练得到的;所述对抗训练的方式为通过生成模型根据所述未标注图像样本生成第一热力图;根据所述标注图像样本中标注的关键点坐标确定第二热力图;通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;根据所述第一匹配度和所述第二匹配度构建对抗损失函数;根据所述对抗损失函数对所述生成模型和判别模型进行训练;
确定单元503,用于根据所述热力图确定关键点坐标。
本领域普通技术人员可以理解:实现上述方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成,前述程序可以存储于一计算机可读取存储介质中,该程序在执行时,执行包括上述方法实施例的步骤;而前述的存储介质可以是下述介质中的至少一种:只读存储器(英文:read-only memory,缩写:ROM)、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
需要说明的是,本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于设备及系统实施例而言,由于其基本相似于方法实施例,所以描述得比较简单,相关之处参见方法实施例的部分说明即可。以上所描述的设备及系统实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
以上所述,仅为本申请的一种具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应该以权利要求的保护范围为准。

Claims (10)

1.一种关键点检测模型训练方法,其特征在于,所述方法包括:
通过生成模型根据获取的未标注图像样本生成第一热力图;
根据获取的标注图像样本中标注的关键点坐标确定第二热力图;
通过判别模型计算所述第一热力图和所述未标注图像样本的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;
根据所述第一匹配度和所述第二匹配度构建对抗损失函数;
根据所述对抗损失函数对所述生成模型和判别模型进行训练。
2.根据权利要求1所述的方法,其特征在于,根据所述对抗损失函数对所述生成模型和判别模型进行训练,包括:
将所述对抗损失函数作为所述判别模型的损失函数对所述判别模型进行训练;
根据所述对抗损失函数和散度损失函数构建所述生成模型的损失函数,对所述生成模型进行训练;所述散度损失函数用于表示所述标注图像样本的所述第二热力图与第三热力图之间的差距;所述第三热力图是所述生成模型根据所述标注图像样本生成的。
3.根据权利要求2所述的方法,其特征在于,所述生成模型的损失函数为LG=LKL-λLadv;其中,LG为所述生成模型的损失函数,LKL为所述散度损失函数,Ladv为所述对抗损失函数,λ为损失权重乘积。
4.根据权利要求1所述的方法,其特征在于,所述根据获取的标注图像样本中标注的关键点坐标确定第二热力图,包括:
根据所述关键点坐标计算均值和协方差;
根据所述均值和所述协方差计算所述第二热力图。
5.一种关键点检测方法,其特征在于,所述方法包括:
获取待检测图像;
通过生成模型生成热力图;所述生成模型是根据标注图像样本和未标注图像样本,与判别图像进行对抗训练得到的;所述对抗训练的方式为通过生成模型根据所述未标注图像样本生成第一热力图;根据所述标注图像样本中标注的关键点坐标确定第二热力图;通过判别模型计算所述第一热力图和所述未标注图像的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;根据所述第一匹配度和所述第二匹配度构建对抗损失函数;根据所述对抗损失函数对所述生成模型和判别模型进行训练;
根据所述热力图确定关键点坐标。
6.一种关键点检测模型训练装置,其特征在于,所述装置包括:
生成单元,用于通过生成模型根据获取的未标注图像样本生成第一热力图;
确定单元,用于根据获取的标注图像样本中标注的关键点坐标确定第二热力图;
计算单元,用于通过判别模型计算所述第一热力图和所述未标注图像的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;
构建单元,用于根据所述第一匹配度和所述第二匹配度构建对抗损失函数;
训练单元,用于根据所述对抗损失函数对所述生成模型和判别模型进行训练。
7.根据权利要求6所述的装置,其特征在于,所述训练单元,用于:
将所述对抗损失函数作为所述判别模型的损失函数对所述判别模型进行训练;
根据所述对抗损失函数和散度损失函数构建所述生成模型的损失函数,对所述生成模型进行训练;所述散度损失函数用于表示所述标注图像样本的所述第二热力图与第三热力图之间的差距;所述第三热力图是所述生成模型根据所述标注图像样本生成的。
8.根据权利要求7所述的装置,其特征在于,所述生成模型的损失函数为LG=LKL-λLadv;其中,LG为所述生成模型的损失函数,LKL为所述散度损失函数,Ladv为所述对抗损失函数,λ为损失权重乘积。
9.根据权利要求6所述的装置,其特征在于,所述确定单元,用于:
根据所述关键点坐标计算均值和协方差;
根据所述均值和所述协方差计算所述第二热力图。
10.一种关键点检测装置,其特征在于,所述装置包括:
获取单元,用于获取待检测图像;
生成单元,用于通过生成模型生成热力图;所述生成模型是根据标注图像样本和未标注图像样本,与判别图像进行对抗训练得到的;所述对抗训练的方式为通过生成模型根据所述未标注图像样本生成第一热力图;根据所述标注图像样本中标注的关键点坐标确定第二热力图;通过判别模型计算所述第一热力图和所述未标注图像的第一匹配度,以及计算所述第二热力图和所述标注图像样本的第二匹配度;根据所述第一匹配度和所述第二匹配度构建对抗损失函数;根据所述对抗损失函数对所述生成模型和判别模型进行训练;
确定单元,用于根据所述热力图确定关键点坐标。
CN202010294788.7A 2020-04-15 2020-04-15 一种关键点检测模型训练方法、关键点检测方法和装置 Active CN111523422B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010294788.7A CN111523422B (zh) 2020-04-15 2020-04-15 一种关键点检测模型训练方法、关键点检测方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010294788.7A CN111523422B (zh) 2020-04-15 2020-04-15 一种关键点检测模型训练方法、关键点检测方法和装置

Publications (2)

Publication Number Publication Date
CN111523422A CN111523422A (zh) 2020-08-11
CN111523422B true CN111523422B (zh) 2023-10-10

Family

ID=71904091

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010294788.7A Active CN111523422B (zh) 2020-04-15 2020-04-15 一种关键点检测模型训练方法、关键点检测方法和装置

Country Status (1)

Country Link
CN (1) CN111523422B (zh)

Families Citing this family (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111985556A (zh) * 2020-08-19 2020-11-24 南京地平线机器人技术有限公司 关键点识别模型的生成方法和关键点识别方法
CN111967406A (zh) * 2020-08-20 2020-11-20 高新兴科技集团股份有限公司 人体关键点检测模型生成方法、系统、设备和存储介质
CN112101490B (zh) * 2020-11-20 2021-03-02 支付宝(杭州)信息技术有限公司 热力图转换模型训练方法以及装置
CN112818809B (zh) * 2021-01-25 2022-10-11 清华大学 一种检测图像信息的方法、装置和存储介质
CN113128436B (zh) * 2021-04-27 2022-04-01 北京百度网讯科技有限公司 关键点的检测方法和装置
CN113569627A (zh) * 2021-06-11 2021-10-29 北京旷视科技有限公司 人体姿态预测模型训练方法、人体姿态预测方法及装置
CN113706463B (zh) * 2021-07-22 2024-04-26 杭州键嘉医疗科技股份有限公司 基于深度学习的关节影像关键点自动检测方法、装置
CN113822254B (zh) * 2021-11-24 2022-02-25 腾讯科技(深圳)有限公司 一种模型训练方法及相关装置

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108133220A (zh) * 2016-11-30 2018-06-08 北京市商汤科技开发有限公司 模型训练、关键点定位及图像处理方法、系统及电子设备
CN108229489A (zh) * 2016-12-30 2018-06-29 北京市商汤科技开发有限公司 关键点预测、网络训练、图像处理方法、装置及电子设备
CN109508681A (zh) * 2018-11-20 2019-03-22 北京京东尚科信息技术有限公司 生成人体关键点检测模型的方法和装置
CN110110745A (zh) * 2019-03-29 2019-08-09 上海海事大学 基于生成对抗网络的半监督x光图像自动标注
CN110210624A (zh) * 2018-07-05 2019-09-06 第四范式(北京)技术有限公司 执行机器学习过程的方法、装置、设备以及存储介质
CN110263845A (zh) * 2019-06-18 2019-09-20 西安电子科技大学 基于半监督对抗深度网络的sar图像变化检测方法
CN110298415A (zh) * 2019-08-20 2019-10-01 视睿(杭州)信息科技有限公司 一种半监督学习的训练方法、系统和计算机可读存储介质
CN110335337A (zh) * 2019-04-28 2019-10-15 厦门大学 一种基于端到端半监督生成对抗网络的视觉里程计的方法
CN110751097A (zh) * 2019-10-22 2020-02-04 中山大学 一种半监督的三维点云手势关键点检测方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US8831358B1 (en) * 2011-11-21 2014-09-09 Google Inc. Evaluating image similarity

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108133220A (zh) * 2016-11-30 2018-06-08 北京市商汤科技开发有限公司 模型训练、关键点定位及图像处理方法、系统及电子设备
CN108229489A (zh) * 2016-12-30 2018-06-29 北京市商汤科技开发有限公司 关键点预测、网络训练、图像处理方法、装置及电子设备
CN110210624A (zh) * 2018-07-05 2019-09-06 第四范式(北京)技术有限公司 执行机器学习过程的方法、装置、设备以及存储介质
CN109508681A (zh) * 2018-11-20 2019-03-22 北京京东尚科信息技术有限公司 生成人体关键点检测模型的方法和装置
CN110110745A (zh) * 2019-03-29 2019-08-09 上海海事大学 基于生成对抗网络的半监督x光图像自动标注
CN110335337A (zh) * 2019-04-28 2019-10-15 厦门大学 一种基于端到端半监督生成对抗网络的视觉里程计的方法
CN110263845A (zh) * 2019-06-18 2019-09-20 西安电子科技大学 基于半监督对抗深度网络的sar图像变化检测方法
CN110298415A (zh) * 2019-08-20 2019-10-01 视睿(杭州)信息科技有限公司 一种半监督学习的训练方法、系统和计算机可读存储介质
CN110751097A (zh) * 2019-10-22 2020-02-04 中山大学 一种半监督的三维点云手势关键点检测方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
"Semi-Supervised Learning Based on Generative Adversarial Network and Its Applied to Lithology Recognition";Guohe Li 等;《IEEE Access》;20190522;第7卷;第67428-67436页 *
"基于自适应对抗学习的半监督图像语义分割";张桂梅 等;《南昌航空大学学报(自然科学版)》;20190915;第33卷(第3期);第33-40页 *

Also Published As

Publication number Publication date
CN111523422A (zh) 2020-08-11

Similar Documents

Publication Publication Date Title
CN111523422B (zh) 一种关键点检测模型训练方法、关键点检测方法和装置
Yang et al. Uncertainty-guided transformer reasoning for camouflaged object detection
Yuan et al. Robust visual tracking with correlation filters and metric learning
Ozay et al. Machine learning methods for attack detection in the smart grid
Chen et al. Part-activated deep reinforcement learning for action prediction
WO2020061489A1 (en) Training neural networks for vehicle re-identification
Du et al. Online deformable object tracking based on structure-aware hyper-graph
Kuhnke et al. Deep head pose estimation using synthetic images and partial adversarial domain adaption for continuous label spaces
CN110147699B (zh) 一种图像识别方法、装置以及相关设备
US10592786B2 (en) Generating labeled data for deep object tracking
Filtjens et al. Skeleton-based action segmentation with multi-stage spatial-temporal graph convolutional neural networks
US9600897B2 (en) Trajectory features and distance metrics for hierarchical video segmentation
CN110163060B (zh) 图像中人群密度的确定方法及电子设备
Yarkony et al. Data association via set packing for computer vision applications
Mehrkanoon et al. Incremental multi-class semi-supervised clustering regularized by Kalman filtering
CN111611395B (zh) 一种实体关系的识别方法及装置
US20230252271A1 (en) Electronic device and method for processing data based on reversible generative networks, associated electronic detection system and associated computer program
Lin et al. Region-based context enhanced network for robust multiple face alignment
CN116665282A (zh) 人脸识别模型训练方法、人脸识别方法及装置
Han et al. Cultural and creative product design and image recognition based on the convolutional neural network model
CN114462526A (zh) 一种分类模型训练方法、装置、计算机设备及存储介质
CN115222047A (zh) 一种模型训练方法、装置、设备及存储介质
Xue et al. Towards gene function prediction via multi-networks representation learning
Liu et al. An improved dual-channel network to eliminate catastrophic forgetting
Wu et al. Fragmentary multi-instance classification

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant