CN114358205A - 模型训练方法、模型训练装置、终端设备及存储介质 - Google Patents

模型训练方法、模型训练装置、终端设备及存储介质 Download PDF

Info

Publication number
CN114358205A
CN114358205A CN202210031145.2A CN202210031145A CN114358205A CN 114358205 A CN114358205 A CN 114358205A CN 202210031145 A CN202210031145 A CN 202210031145A CN 114358205 A CN114358205 A CN 114358205A
Authority
CN
China
Prior art keywords
image
loss value
recognition
loss
trained
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210031145.2A
Other languages
English (en)
Inventor
司世景
王健宗
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202210031145.2A priority Critical patent/CN114358205A/zh
Publication of CN114358205A publication Critical patent/CN114358205A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请适用于人工智能技术领域,提供了一种模型训练方法、模型训练装置、终端设备及存储介质。该方法包括:针对样本集中的每个样本图像执行两种不同的数据增广操作,得到第一图像和第二图像;基于待训练的重识别模型对第一图像和第二图像进行行人重识别得到重识别预测结果;基于重识别预测结果计算待训练的重识别模型的第一损失值、第二损失值以及第三损失值;基于第一损失值、第二损失值、第三损失值及预设的目标损失函数计算得到待训练的重识别模型的总损失值,并基于总损失值对待训练的重识别模型优化,直至目标损失函数收敛,得到训练完成的重识别模型。该方法得到的重识别模型,具备较强的鲁棒性和泛化能力。此外,本申请还涉及区块链技术。

Description

模型训练方法、模型训练装置、终端设备及存储介质
技术领域
本申请涉及人工智能技术领域,尤其涉及一种模型训练方法、模型训练装置、终端设备及计算机可读存储介质。
背景技术
行人重识别(person re-identification,Re-ID)是一项跨设备查找特定行人的图像检索技术,旨在利用计算机视觉技术检索出非重叠摄像机下的特定行人图像。在模型训练任务中,由于拍摄场景的复杂性,行人图像受拍摄角度、光照以及姿态等因素影响较大,使得在不同场景,甚至同一场景下不同的摄像头获取的图像风格差异较大;加之,在同一个数据集中,难免存在类间相似度高和类内差异小的问题,导致训练得到的模型识别准确性偏低且泛化能力较差。
目前,为了提升模型的鲁棒性和泛化能力,大部分行人重识别算法往往基于设计复杂的网络结构来提高模型的表征能力,但是该方法提高模型表征能力的同时降低了算法效率且增加了应用难度。因此,如何提供一种简单且便于应用的模型训练方法是当前亟待解决的问题。
发明内容
有鉴于此,本申请实施例提供了一种模型训练方法、模型训练装置、终端设备及计算机可读存储介质,通过改变模型训练方法,从而提高重识别模型的鲁棒性和泛化能力。
本申请实施例的第一方面提供了一种模型训练方法,包括:
针对样本集中的每个样本图像,对上述样本图像执行两种不同的数据增广操作,得到第一图像和第二图像;
基于待训练的重识别模型分别对上述第一图像和上述第二图像进行行人重识别得到重识别预测结果,上述重识别模型包括主干网络、平均池化层以及全连接层;
基于上述重识别预测结果计算上述待训练的重识别模型的第一损失值、第二损失值以及第三损失值,上述第一损失值用于计算图像所属标签的真实概率和预测概率之间的损失,上述第二损失值用于计算图像的真实区别特征与提取区别特征之间的损失,上述第三损失值用于计算图像的真实标签与预测标签之间的损失;
基于上述第一损失值、上述第二损失值、上述第三损失值及预设的目标损失函数计算得到上述待训练的重识别模型的总损失值,并基于上述总损失值对上述待训练的重识别模型进行优化,直至上述目标损失函数收敛,得到训练完成的重识别模型。
本申请实施例的第二方面提供了一种模型训练装置,包括:
数据增广模块,用于针对样本集中的每个样本图像,对上述样本图像执行两种不同的数据增广操作,得到第一图像和第二图像;
重识别模块,用于基于待训练的重识别模型分别对上述第一图像和上述第二图像进行行人重识别得到重识别预测结果,上述重识别模型包括主干网络、平均池化层以及全连接层;
计算模块,用于基于上述重识别预测结果计算上述待训练的重识别模型的第一损失值、第二损失值以及第三损失值,上述第一损失值用于描述图像所属标签的真实概率和预测概率之间的损失,上述第二损失值用于描述图像的真实区别特征与提取区别特征之间的损失,上述第三损失值用于描述图像的真实标签与预测标签之间的损失;
优化模块,用于基于上述第一损失值、上述第二损失值、上述第三损失值及预设的目标损失函数计算得到上述待训练的重识别模型的总损失值,并基于上述总损失值对上述待训练的重识别模型进行优化,直至上述目标损失函数收敛,得到训练完成的重识别模型。
本申请实施例的第三方面提供了一种终端设备,包括存储器、处理器以及存储在上述存储器中并可在终端设备上运行的计算机程序,上述处理器执行上述计算机程序时实现第一方面提供的模型训练方法的各步骤。
本申请实施例的第四方面提供了一种计算机可读存储介质,上述计算机可读存储介质存储有计算机程序,上述计算机程序被处理器执行时实现第一方面提供的模型训练方法的各步骤。
实施本申请实施例提供的一种模型训练方法、模型训练装置、终端设备及计算机可读存储介质具有以下有益效果:
通过对样本集中的每个样本图像执行两种不同的数据增广操作,可以得到第一图像和第二图像,基于待训练的重识别模型分别对第一图像和第二图像进行行人重识别得到重识别预测结果,重识别模型包括主干网络、平均池化层以及全连接层;基于重识别预测结果计算待训练的重识别模型的第一损失值、第二损失值以及第三损失值,第一损失值用于描述图像所属标签的真实概率和预测概率之间的损失,第二损失值用于描述图像的真实区别特征与提取区别特征之间的损失,第三损失值用于描述图像的真实标签与预测标签之间的损失;基于第一损失值、第二损失值、第三损失值及预设的目标损失函数计算得到待训练的重识别模型的总损失值,并基于总损失值对待训练的重识别模型进行优化,直至目标损失函数收敛,得到训练完成的重识别模型。该模型训练方法从行人重识别任务中的风格偏差问题出发,基于对样本图像的数据增广操作,促使模型在训练过程中能够学习不同风格的图像中不变的本质辨别性特征,从而提高模型的表征能力;通过三种损失与预设的目标损失函数对模型进行优化,能够让训练得到的重识别模型具备较强的鲁棒性和泛化能力。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的一种模型训练方法的实现流程图;
图2是本申请实施例提供的一种重识别模型的结构框图;
图3是本申请实施例提供的一种模型训练装置的结构框图;
图4是本申请实施例提供的一种终端设备的结构框图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
本申请实施例所涉及的模型训练方法,可以由终端设备,例如笔记本电脑、超级移动个人计算机(ultra-mobile personal computer,UMPC)、上网本或个人数字助理(personal digital assistant,PDA)执行。
本申请实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统及机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术及自然语言处理技术以及机器学习/深度学习等几大方向。
本申请实施例涉及的模型训练方法,可以应用于智慧交通的场景中,能够推动智慧城市的建设。
请参阅图1,图1示出了本申请实施例提供的一种模型训练方法的实现流程图。该确定方法包括:
步骤110、针对样本集中的每个样本图像,对样本图像执行两种不同的数据增广操作,得到第一图像和第二图像。
为了让训练完成的重识别模型能够识别不同风格的图像中不变的特征,可以对样本集中的每个样本图像执行两种不同的数据增广操作,从而得到风格不同的第一图像和第二图像。举例来说,对于某个样本图像,可以执行旋转操作后,得到第一图像;还可以执行平移后,得到第二图像。后续基于第一图像和第二图像对待训练的重识别模型进行训练,能够促使重识别模型分别从第一图像和第二图像中学习本质的辨别性特征,从而提高模型的表征能力。其中,样本集是训练模型的学习样本数据集,该样本集中的图像可以通过网上爬取的方式获得,例如从ImageNet这一数据库中爬取。
步骤120、基于待训练的重识别模型分别对第一图像和第二图像进行行人重识别得到重识别预测结果。
在得到第一图像和第二图像之后,即可将第一图像和第二图像输入待训练的重识别模型中,以实现行人重识别,得到重识别预测结果。其中,重识别模型包括主干网络、平均池化层以及全连接层。
步骤130、基于重识别预测结果计算待训练的重识别模型的第一损失值、第二损失值以及第三损失值。
在得到重识别结果后,可以基于重识别结果得到待训练的重识别模型的三个损失,分别对应第一损失值、第二损失值以及第三损失值,其中第一损失值用于描述图像所属标签的真实概率和预测概率之间的损失,第二损失值用于描述图像的真实区别特征与提取区别特征之间的损失,第三损失值用于描述图像的真实标签与预测标签之间的损失。
步骤140、基于第一损失值、第二损失值、第三损失值及预设的目标损失函数计算得到待训练的重识别模型的总损失值,并基于总损失值对待训练的重识别模型进行优化,直至目标损失函数收敛,得到训练完成的重识别模型。
在得到第一损失值、第二损失值以及第三损失值之后,可以根据预设的目标损失函数,计算待训练的重识别模型的总损失值,并基于总损失值对待训练的重识别模型进行优化。优化后的重识别模型,不确定其是否能够达到预期的性能,因此可以返回执行步骤110-140步骤以检验当前的重识别模型是否需要再次被优化。具体地,在返回执行的过程中,当再次执行140步骤时,可以将总损失值与设定的损失阈值进行比较,如果总损失值小于损失阈值,则说明目标损失函数收敛,即此时得到的重识别模型即为训练完成的重识别模型;但如果总损失值大于损失阈值,则说明目标损失函数未收敛,需要再次基于当前的总损失值对待训练的重识别模型进行优化,直至总损失之小于损失阈值,得到训练完成的重识别模型。
该模型训练方法先从行人重识别任务中的风格偏差问题出发,基于对样本图像的数据增广操作,促使模型在训练过程中能够学习样本图像的本质辨别性特征,从而提高模型的表征能力;然后通过三种损失与预设的目标损失函数对模型进行优化,能够让训练得到的重识别模型具备较强的鲁棒性和泛化能力。
在一些实施例中,为了增强训练完成的模型的泛化能力且提高其识别的准确性,数据增广操作可以包括水平/垂直翻转、旋转、缩放、裁剪、剪切、平移、对比度、色彩抖动以及噪声等。
数据增广是深度学习中常用的技巧之一,主要用于增加训练数据集,让数据集尽可能的多样化,使得训练的模型具有更强的泛化能力。而在本申请中,不仅仅是为了让训练完成的重识别模型具备更强的泛化能力,还为了通过数据增广操作提高重识别模型的识别准确性。这是因为在复杂的拍摄场景中,行人图像受摄角度、光照及姿态等因素影响较大,导致同一行人的多张图像风格差异加大,进而导致模型的识别准确度降低。为了解决这一问题,本申请实施例通过对样本图像执行两种不同的数据增广操作,模拟出两张风格不同但所指向的行人相同的图像,即第一图像和第二图像,之后利用第一图像和第二图像对待训练的重识别模型进行训练,能够让训练得到的重识别模型能够学习行人图像本质的辨别性特征。因此即使在复杂的拍摄场景中,训练完成的重识别模型也能够从风格各异的行人图像中准确识别出特定的行人,从而提高行人重识别的准确性。
在一些实施例中,参阅图2,待训练的重识别模型是迁移Imagenet上预训练的ResNet50,其中ResNet50移除平均池化层和全连接层后所剩余的部分即为主干网络。具体地,上述步骤120具体包括:
步骤121、基于主干网络分别对第一图像和第二图像进行特征提取,得到第一特征和第二特征。
步骤122、基于平均池化层将第一特征和第二特征向量化,得到第一特征向量和第二特征向量。
步骤123、基于全连接层、第一特征向量以及第二特征向量进行行人重识别,得到重识别预测结果。
将第一图像及第二图像中的每个图像输入主干网络,以执行特征提取操作,能够为待训练的重识别模型进行重识别提供信息和冗余的派生值,从而促进该待训练的重识别模型后续的学习和泛化步骤。在某些情况下,特征提取能为识别过程带来可解释性。对于图像而言,其本身属于二维数据,因此特征提取可以理解为一种降维方式。在得到第一特征和第二特征后,可以将这两个特征输入平均池化层,以将第一特征和第二特征向量化,得到第一特征向量和第二特征向量。之所以要对两个特征执行向量化操作,不仅是为了提升待训练的重识别模型的识别效率,而且还为了小化特征图大小,进而减少计算量和所需显存,从而提高模型训练的效率。在得到第一特征向量和第二特征向量后,可以将这两个特征向量输入全连接层中,以整合池化层中具有类别区分性的局部信息,实现行人重识别,得到重识别预测结果。可以理解的是,对于行人重识别这一分类任务,其可以是单分类任务也可以是多分类任务:如果重识别对象为某一个人,那么该重识别任务即为单分类任务;但如果重识别对象为某几个人,则该重识别任务为多分类任务。
本申请实施例所设计的重识别模型,可以视作一个基线网络,不仅具有很强的特征表示能力和泛化能力,还能应用于不同领域的数据场景。同时,也可以将该基线网络衔接到其他优秀的行人重识别算法模型中,能有效提高识别精度。
在一些实施例中,重识别预测结果可以包括第一图像的第一预测概率以及第二图像的第二预测概率,具体地,上述第一损失值可以通过以下步骤确定:
基于KL散度及预设的第一损失函数对第一预测概率和第二预测概率进行约束,得到第一损失值,第一损失函数为:
Figure BDA0003466489360000081
其中,p1为第一图像的第一预测概率,p2为第二图像的第二预测概率,DKL(p1||p2)为第一图像与第二图像的相对熵,DKL(p2||p1)为第二图像与第一图像的相对熵。由于相对熵是不对称的,因此在本申请实施例中采用了两个相对熵进行计算。并且在本申请中,第一损失值由两个部分组成,一部分是基于第一图像计算得到的第一损失值,另一部分是基于第二图像计算得到的第二损失值,其中,第一图像指的是样本集中所有图像执行某种数据增广操作得到的第一图像,并非是单张图像,对于第二图像类似。也就是说,第一损失值是在处理完样本集中的每张图像后进行计算得到的。
由于这第一图像和第二图像是执行不同的数据增广操作后得到的结果,因此属于不同的视图,将这两个图像经过同一个重识别模型进行预测,可以得到两个不同的概率。但由于第一图像和第二图像是经过同一样本图像得到的,因此这两个图像对应的是同一个行人。也就是说,如果模型能够准确识别第一图像和第二图像,那么得到的预测概率应该是相近的,利用两个图片的预测概率应该相近这一原则,可以提高重识别模型识别的准确性。具体地,可以采用对称KL散度对两个预测概率进行约束,从而让重识别模型略过两个图像中的风格信息,学习与行人身份相关的特征信息,以提高重识别模型识别的准确性和泛化性。
在一些实施例中,为了让重识别模型得到更好的表征,可以结合难样本采样三元组损失(Triplet loss with batch hard mining,TriHard loss)对得到的两个特征向量进行约束。重识别预测结果包括正负样本对距离,第二损失值通过正负样本对距离和预设的第二损失函数计算,正负样本对距离通过以下步骤获得:
210、针对第一图像及第二图像中的每种图像,将图像分为至少1个批次,每个批次包括P个行人标签,每个行人标签对应有K张图像,每个行人标签对应的K张图像中,包括一张固定图像a。
220、针对每个批次中的每张固定图像a,将与固定图像a的行人标签相同的图像确定为正样本集A,将与固定图像a的行人标签不相同的图像确定为负样本集B,基于固定图像a、正样本集A以及负样本集B计算正负样本对距离。
难样本采样三元组损失是三元组损失的改进版本,传统的三元祖损失是随机从样本集中抽取三张图像,以计算正、负样本对距离,但是该方法抽取到的大部分样本都是能够简易区分的样本,如果大量训练的样本对都是简单的样本对,不利于模型学习到更好的表征。因此本申请实施例采用难样本采样的方式,以获得不利于区分的样本对,基于这样的样本对待训练的重识别模型进行训练,能够提高重识别模型的表征能力。
具体地,本申请实施例中有两种图像,即所有第一图像可以作为一种图像,所有第二图像可以作为另一种图像。可以按批次从每种图像中抽取图像,并且针对一个批次,其包含P*K张图像,P为行人标签,K为每个行人标签对应的图像数量,行人标签相同的K张图像中,包含了一张固定图像a,该固定图像a可以理解为准确预测行人标签的基准图像。由于在步骤110中,一张图像经过两种不同的数据增广操作得到了两张不同风格的图像,即第一图像和第二图像,因此在按照批次采样的过程中,可以将所有第一图像作为一个待采样集合,将所有第二图像作为另一个待采样集合,在采样的过程中,分别对每个待采样集合执行上述的批次采样。
在完成采样后,针对每个批次,可以基于该批次中的每张固定图像a划分出p组正、负样本集,即将与固定图像a的行人标签相同的图像确定为正样本集A,将与固定图像a的行人标签不相同的图像确定为负样本集B。
为了便于理解,举例说明,假设第一个批次中,有3个行人标签,即张三、李四和王五,其中,每个行人标签对应有5张图像,第一批次中有15张图像,行人标签为张三的图像有5张,行人标签为李四的图像有5张,行人标签为王五的图像有5张。行人标签为张三的5张图像中有一张固定图像a1,行人标签为李四的5张图像中有一张固定图像a2,行人标签为王五的5张图像中有一张固定图像a3,针对a1、a2以及a3三张固定图像中的每张固定图像,可以确定出一组正负样本集,即针对a1,可以与其行人标签相同的另外4张图像组成正样本集A1,将于其行人标签不相同的两外10张图像组成负样本集B1,与其行人标签不相同的两外10张图像也即第一批次中行人标签为李四和王五的10张图像。对于a2和a3,以上述a1类推。通过上述操作,可以得到三组正负样本集。且针对每一组正负样本集,可以结合对应的固定图像a计算正负样本对距离。
可选地,可以利用欧式距离、马氏距离以及余弦相似度等表征两个图像之间的距离。
在一些实施例中,正负样本对距离包括正样本对距离和负样本对距离,基于固定图像a、正样本集A以及负样本集B计算正负样本对距离具体包括:
221、计算正样本集A中每张图像与固定图像a之间的正样本对距离。
222计算负样本集B中每张图像与固定图像a之间的负样本对距离。
以上述固定图像a1、正样本集A1以及正样本集B1为例,可以计算a1与正样本集A1中每个图像之间的距离,得到4个正样本对距离,计算a1与正样本集B1中每个图像之间的距离,得到10个负样本对距离。
可选地,在计算出正负样本对距离之后,为了混淆特征提取器,可以从其中挑选出最难的正样本和最难的负样本,即将4个正样本对距离中的最大值对应的图像作为正样本图像,将10个负样本对距离中的最小值对应的图像作为负样本图像,从而组成难样本采样三元组,以便于计算第二损失值。
基于该难样本采样三元组损失对待训练的重识别模型进行优化,能够拉近正样本对之间的距离,推离负样本对之间的距离,使得数据集中类内距离更近,类间距离更远,从而提高重识别模型的识别能力。
具体地,第二损失函数的公式如下:
Figure BDA0003466489360000111
其中,p为正样本,n为负样本,da,p和da,n分别表示正、负样本对的距离,(·)+表示max(·,0)。应当理解的是,对于每种图像,均采用上述公式计算每种图像的第二损失值,在计算完两种图像的第二损失值后,可以进一步计算总的第二损失值。
在一些实施例中,上述第三损失值通过预设的第三损失函数计算得到,第三损失函数为:
Figure BDA0003466489360000112
其中,N为样本集中行人标签的类别总数,qi为预测标签,pi为图像中的行人属于第i类行人标签的预测概率,y为图像中的行人的真实标签。和第二损失值的计算相同,对于每种图像,均采用上述公式计算每种图像的第三损失值,在计算完两种图像的第三损失值后,可以进一步计算总的第三损失值。
同样的,为了保持三个损失函数计算的一致性,第三损失值包括有两个部分组成,即基于第一图像计算得到的第三损失值和基于第二图像计算得到的第三损失值。
基于上述三种损失值,即可计算出总损失值,其中,计算总损失值所使用的目标损失函数如下所示:
Figure BDA0003466489360000113
其中,LKL为基于第一损失值,所述
Figure BDA0003466489360000121
为基于第一图像计算得到的第二损失值,
Figure BDA0003466489360000122
为基于第二图像计算得到的第二损失值,
Figure BDA0003466489360000123
为基于第一图像计算得到的第三损失值,
Figure BDA0003466489360000124
为基于第二图像计算得到的第三损失值。也就是说,上述所计算的总损失值,是在对所有的第一图像和所有第二图像进行识别后计算得到的。例如,样本图像有1000张,经过数据增广操作,得到1000张第一图像和1000张第二图像,当对2000张图像均实现识别后,才会根据上述总目标损失函数计算总损失值,以便于对重识别模型进行优化。
在一些实施例中,在上述步骤140之后,上述模型训练方法还包括:
将上述样本集、待训练的重识别模型和/或训练完成的重识别模型上传至区块链(Blockchain)中。
其中,为了保证数据的安全性和对用户的公正透明性,可以将本集、待训练的重识别模型和/或训练完成的重识别模型上传至区块链进行存证。用户随后即可通过各自的设备从区块链中下载获得本集、待训练的重识别模型和/或训练完成的重识别模型,以便查证这些数据是否被篡改。本实施例所指区块链是采用分布式数据存储、点对点传输、共识机制及加密算法等计算机技术的新型应用模式。区块链,本质上是一个去中心化的数据库,是一串使用密码学方法相关联产生的数据块,每一个数据块中包含了一批次网络交易的信息,用于验证其信息的有效性(防伪)和生成下一个区块。区块链可以包括区块链底层平台、平台产品服务层以及应用服务层等。
此外,本申请实施例还提供了一种模型训练装置。
请参阅图3,图3是本申请实施例提供的一种模型训练装置的结构框图。本实施例中该终端设备包括的各单元用于执行图1对应的实施例中的各步骤。具体请参阅图1以及图1所对应的实施例中的相关描述。为了便于说明,仅示出了与本实施例相关的部分。参见图3,模型训练装置30包括:
数据增广模块31,用于针对样本集中的每个样本图像,对样本图像执行两种不同的数据增广操作,得到第一图像和第二图像;
重识别模块32,用于基于待训练的重识别模型分别对第一图像和第二图像进行行人重识别得到重识别预测结果,重识别模型包括主干网络、平均池化层以及全连接层;
计算模块33,用于基于重识别预测结果计算待训练的重识别模型的第一损失值、第二损失值以及第三损失值,第一损失值用于计算图像所属标签的真实概率和预测概率之间的损失,第二损失值用于计算图像的真实区别特征与提取区别特征之间的损失,第三损失值用于计算图像的真实标签与预测标签之间的损失;
优化模块34,用于基于第一损失值、第二损失值、第三损失值及预设的目标损失函数计算得到待训练的重识别模型的总损失值,并基于总损失值对待训练的重识别模型进行优化,直至目标损失函数收敛,得到训练完成的重识别模型。
作为本申请一实施例,上述重识别预测结果包括第一图像的第一预测概率以及第二图像的第二预测概率,上述计算模块33具体用于:
基于KL散度及预设的第一损失函数对第一预测概率和第二预测概率进行约束,得到第一损失值,第一损失函数为:
Figure BDA0003466489360000131
其中,p1为第一图像的第一预测概率,p2为第二图像的第二预测概率,DKL(p1||p2)为第一图像的相对熵,DKL(p2||p1)为第二图像的相对熵。
作为本申请一实施例,重识别预测结果包括正负样本对距离,上述计算模块33具体用于通过正负样本对距离和预设的第二损失函数计算第二损失值,上述模型训练装置还包括:
采样模块,用于针对第一图像及第二图像中的每个图像,将图像分为至少1个批次,每个批次包括P个行人标签,每个行人标签对应有K张图像,每个行人标签对应的K张图像中,包括一张固定图像a;
第一确定模块,用于针对每个批次中的每张固定图像a,将与固定图像a的行人标签相同的图像确定为正样本集A,将与固定图像a的行人标签不相同的图像确定为负样本集B,基于固定图像a、正样本集A以及负样本集B计算正负样本对距离。
作为本申请一实施例,上述正负样本对距离包括正样本对距离和负样本对距离,上述第一确定模块可以包括:
第一计算单元,用于计算正样本集A中每张图像与固定图像a之间的正样本对距离;
第二计算单元,用于计算负样本集B中每张图像与固定图像a之间的负样本对距离。
作为本申请一实施例,上述计算模块33具体用于通过预设的第三损失函数计算第三损失值,第三损失函数为:
Figure BDA0003466489360000141
其中,N为样本集中行人标签的类别总数,qi为预测标签,pi为图像中的行人属于第i类行人标签的预测概率,y为图像中的行人的真实标签。
作为本申请一实施例,重识别模块32还包括:
特征提取单元,用于基于主干网络分别对第一图像和第二图像进行特征提取,得到第一特征和第二特征;
向量化单元,用于基于平均池化层将第一特征和第二特征向量化,得到第一特征向量和第二特征向量;
重识别单元,用于基于全连接层、第一特征向量以及第二特征向量进行行人重识别,得到重识别预测结果。
作为本申请一实施例,上述模型训练装置30还包括:
上传模块,用于将样本集、待训练的重识别模型和/或训练完成的重识别模型上传至区块链中。
应当理解的是,图3示出的模型训练装置的结构框图中,各单元用于执行图1对应的实施例中的各步骤,而对于图1对应的实施例中的各步骤已在上述实施例中进行详细解释,具体请参阅图1以及图1所对应的实施例中的相关描述,此处不再赘述。
图4是本申请另一实施例提供的一种终端设备的结构框图。如图4所示,该实施例的终端设备40包括:处理器41、存储器42以及存储在上述存储器42中并可在上述处理器41上运行的计算机程序43,例如模型训练方法的程序。处理器41执行上述计算机程序43时实现上述各个模型训练方法各实施例中的步骤,例如图1所示的110至140。或者,上述处理器41执行上述计算机程序43时实现上述图3对应的实施例中各模块的功能,例如,图3所示的模块31至34的功能,具体请参阅图3对应的实施例中的相关描述,此处不赘述。
示例性的,上述计算机程序43可以被分割成一个或多个单元,上述一个或者多个单元被存储在上述存储器42中,并由上述处理器41执行,以完成本申请。上述一个或多个单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述上述计算机程序43在上述终端40中的执行过程。例如,上述计算机程序43可以被分割成数据增广模块31、重识别模块32、计算模块33以及优化模块34,各模块具体功能如上。
上述终端设备可包括,但不仅限于,处理器41、存储器42。本领域技术人员可以理解,图4仅仅是终端设备40的示例,并不构成对终端设备40的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如上述终端设备还可以包括输入输出设备、网络接入设备、总线等。
所称处理器41可以是中央处理单元(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
上述存储器42可以是上述终端设备40的内部存储单元,例如终端设备40的硬盘或内存。上述存储器42也可以是上述终端设备40的外部存储设备,例如上述终端设备40上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。进一步地,上述存储器42还可以既包括上述终端设备40的内部存储单元也包括外部存储设备。上述存储器42用于存储上述计算机程序以及上述终端设备所需的其他程序和数据。上述存储器42还可以用于暂时地存储已经输出或者将要输出的数据。
上述实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。

Claims (10)

1.一种模型训练方法,其特征在于,所述模型训练方法包括:
针对样本集中的每个样本图像,对所述样本图像执行两种不同的数据增广操作,得到第一图像和第二图像;
基于待训练的重识别模型分别对所述第一图像和所述第二图像进行行人重识别得到重识别预测结果,所述重识别模型包括主干网络、平均池化层以及全连接层;
基于所述重识别预测结果计算所述待训练的重识别模型的第一损失值、第二损失值以及第三损失值,所述第一损失值用于描述图像所属标签的真实概率和预测概率之间的损失,所述第二损失值用于描述图像的真实区别特征与提取区别特征之间的损失,所述第三损失值用于描述图像的真实标签与预测标签之间的损失;
基于所述第一损失值、所述第二损失值、所述第三损失值及预设的目标损失函数计算得到所述待训练的重识别模型的总损失值,并基于所述总损失值对所述待训练的重识别模型进行优化,直至所述目标损失函数收敛,得到训练完成的重识别模型。
2.根据权利要求1所述的模型训练方法,其特征在于,所述重识别预测结果包括所述第一图像的第一预测概率以及所述第二图像的第二预测概率,所述第一损失值通过如下方式计算得到:
基于KL散度及预设的第一损失函数对所述第一预测概率和所述第二预测概率进行约束,得到所述第一损失值,所述第一损失函数为:
Figure FDA0003466489350000011
其中,所述p1为第一图像的第一预测概率,所述p2为第二图像的第二预测概率,所述DKL(p1||p2)为所述第一图像的相对熵,所述DKL(p2||p1)为所述第二图像的相对熵。
3.根据权利要求1所述的模型训练方法,其特征在于,所述重识别预测结果包括正负样本对距离,所述第二损失值通过所述正负样本对距离和预设的第二损失函数计算,所述正负样本对距离通过以下步骤获得:
针对第一图像及第二图像中的每种图像,将所述图像分为至少1个批次,每个批次包括P个行人标签,每个行人标签对应有K张图像,每个所述行人标签对应的K张图像中,包括一张固定图像a;
针对每个批次中的每张固定图像a,将与所述固定图像a的行人标签相同的所述图像确定为正样本集A,将与所述固定图像a的行人标签不相同的所述图像确定为负样本集B,基于所述固定图像a、所述正样本集A以及所述负样本集B计算所述正负样本对距离。
4.根据权利要求3所述的模型训练方法,其特征在于,所述正负样本对距离包括正样本对距离和负样本对距离,所述基于所述固定图像a、所述正样本集A以及所述负样本集B计算所述正负样本对距离包括:
计算正样本集A中每张所述图像与所述固定图像a之间的正样本对距离;
计算负样本集B中每张所述图像与所述固定图像a之间的负样本对距离。
5.根据权利要求2所述的模型训练方法,其特征在于,所述第三损失值通过预设的第三损失函数计算得到,所述第三损失函数为:
Figure FDA0003466489350000021
其中,所述N为所述样本集中行人标签的类别总数,所述qi为预测标签,所述pi为图像中的行人属于第i类行人标签的预测概率,所述y为图像中的行人的真实标签。
6.根据权利要求1-5任意一项所述的模型训练方法,其特征在于,所述基于待训练的重识别模型分别对所述第一图像和所述第二图像进行行人重识别得到重识别预测结果,包括:
基于所述主干网络分别对所述第一图像和所述第二图像进行特征提取,得到第一特征和第二特征;
基于所述平均池化层将所述第一特征和所述第二特征向量化,得到第一特征向量和第二特征向量;
基于所述全连接层、第一特征向量以及第二特征向量进行行人重识别,得到重识别预测结果。
7.根据权利要求1所述的模型训练方法,其特征在于,在所述得到训练完成的重识别模型之后,所述模型训练方法还包括:
将样本集、待训练的重识别模型和/或训练完成的重识别模型上传至区块链中。
8.一种模型训练装置,其特征在于,所述模型训练装置包括:
数据增广模块,用于针对样本集中的每个样本图像,对所述样本图像执行两种不同的数据增广操作,得到第一图像和第二图像;
重识别模块,用于基于待训练的重识别模型分别对所述第一图像和所述第二图像进行行人重识别得到重识别预测结果,所述重识别模型包括主干网络、平均池化层以及全连接层;
计算模块,用于基于所述重识别预测结果计算所述待训练的重识别模型的第一损失值、第二损失值以及第三损失值,所述第一损失值用于描述图像所属标签的真实概率和预测概率之间的损失,所述第二损失值用于描述图像的真实区别特征与提取区别特征之间的损失,所述第三损失值用于描述图像的真实标签与预测标签之间的损失;
优化模块,用于基于所述第一损失值、所述第二损失值、所述第三损失值及预设的目标损失函数计算得到所述待训练的重识别模型的总损失值,并基于所述总损失值对所述待训练的重识别模型进行优化,直至所述目标损失函数收敛,得到训练完成的重识别模型。
9.一种终端设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述方法的步骤。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述方法的步骤。
CN202210031145.2A 2022-01-12 2022-01-12 模型训练方法、模型训练装置、终端设备及存储介质 Pending CN114358205A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210031145.2A CN114358205A (zh) 2022-01-12 2022-01-12 模型训练方法、模型训练装置、终端设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210031145.2A CN114358205A (zh) 2022-01-12 2022-01-12 模型训练方法、模型训练装置、终端设备及存储介质

Publications (1)

Publication Number Publication Date
CN114358205A true CN114358205A (zh) 2022-04-15

Family

ID=81109000

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210031145.2A Pending CN114358205A (zh) 2022-01-12 2022-01-12 模型训练方法、模型训练装置、终端设备及存储介质

Country Status (1)

Country Link
CN (1) CN114358205A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114581838A (zh) * 2022-04-26 2022-06-03 阿里巴巴达摩院(杭州)科技有限公司 图像处理方法、装置和云设备
CN117649683A (zh) * 2024-01-30 2024-03-05 深圳市宗匠科技有限公司 一种痤疮分级方法、装置、设备及存储介质

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110414432A (zh) * 2019-07-29 2019-11-05 腾讯科技(深圳)有限公司 对象识别模型的训练方法、对象识别方法及相应的装置
CN111178115A (zh) * 2018-11-12 2020-05-19 北京深醒科技有限公司 对象识别网络的训练方法及系统
CN112241664A (zh) * 2019-07-18 2021-01-19 顺丰科技有限公司 人脸识别方法、装置、服务器及存储介质
CN112801235A (zh) * 2021-04-12 2021-05-14 四川大学 模型训练方法、预测方法、装置、重识别模型及电子设备
CN113688757A (zh) * 2021-08-30 2021-11-23 五邑大学 一种sar图像识别方法、装置及存储介质
CN113723236A (zh) * 2021-08-17 2021-11-30 广东工业大学 一种结合局部阈值二值化图像的跨模态行人重识别方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111178115A (zh) * 2018-11-12 2020-05-19 北京深醒科技有限公司 对象识别网络的训练方法及系统
CN112241664A (zh) * 2019-07-18 2021-01-19 顺丰科技有限公司 人脸识别方法、装置、服务器及存储介质
CN110414432A (zh) * 2019-07-29 2019-11-05 腾讯科技(深圳)有限公司 对象识别模型的训练方法、对象识别方法及相应的装置
CN112801235A (zh) * 2021-04-12 2021-05-14 四川大学 模型训练方法、预测方法、装置、重识别模型及电子设备
CN113723236A (zh) * 2021-08-17 2021-11-30 广东工业大学 一种结合局部阈值二值化图像的跨模态行人重识别方法
CN113688757A (zh) * 2021-08-30 2021-11-23 五邑大学 一种sar图像识别方法、装置及存储介质

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
刘树春等: "深度实践OCR:基于深度学习的文本识别", vol. 1, 30 April 2020, 机械工业出版社, pages: 123 *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114581838A (zh) * 2022-04-26 2022-06-03 阿里巴巴达摩院(杭州)科技有限公司 图像处理方法、装置和云设备
CN114581838B (zh) * 2022-04-26 2022-08-26 阿里巴巴达摩院(杭州)科技有限公司 图像处理方法、装置和云设备
CN117649683A (zh) * 2024-01-30 2024-03-05 深圳市宗匠科技有限公司 一种痤疮分级方法、装置、设备及存储介质
CN117649683B (zh) * 2024-01-30 2024-04-09 深圳市宗匠科技有限公司 一种痤疮分级方法、装置、设备及存储介质

Similar Documents

Publication Publication Date Title
CN110414432B (zh) 对象识别模型的训练方法、对象识别方法及相应的装置
CN111797893B (zh) 一种神经网络的训练方法、图像分类系统及相关设备
Zhang et al. Action recognition in still images with minimum annotation efforts
Hasani et al. Spatio-temporal facial expression recognition using convolutional neural networks and conditional random fields
Zhang et al. Loop closure detection for visual SLAM systems using convolutional neural network
Sun et al. A robust approach for text detection from natural scene images
Baró et al. Traffic sign recognition using evolutionary adaboost detection and forest-ECOC classification
Wu et al. Discriminative deep face shape model for facial point detection
WO2015192263A1 (en) A method and a system for face verification
Xia et al. Loop closure detection for visual SLAM using PCANet features
Taheri et al. Animal classification using facial images with score‐level fusion
WO2014205231A1 (en) Deep learning framework for generic object detection
Wang et al. Two-stage method based on triplet margin loss for pig face recognition
CN111709311A (zh) 一种基于多尺度卷积特征融合的行人重识别方法
Ilonen et al. Image feature localization by multiple hypothesis testing of Gabor features
CN114358205A (zh) 模型训练方法、模型训练装置、终端设备及存储介质
Roy et al. Deep metric and hash-code learning for content-based retrieval of remote sensing images
CN112861695A (zh) 行人身份再识别方法、装置、电子设备及存储介质
Wu et al. Improving pedestrian detection with selective gradient self-similarity feature
Wang et al. Online visual place recognition via saliency re-identification
Hu et al. Adversarial binary mutual learning for semi-supervised deep hashing
Giraddi et al. Flower classification using deep learning models
Najibi et al. Towards the success rate of one: Real-time unconstrained salient object detection
CN108496174B (zh) 用于面部识别的方法和系统
CN114092873A (zh) 一种基于外观与形态解耦的长时期跨摄像头目标关联方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination