CN116663648B - 模型训练方法、装置、设备及存储介质 - Google Patents

模型训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN116663648B
CN116663648B CN202310444646.8A CN202310444646A CN116663648B CN 116663648 B CN116663648 B CN 116663648B CN 202310444646 A CN202310444646 A CN 202310444646A CN 116663648 B CN116663648 B CN 116663648B
Authority
CN
China
Prior art keywords
similarity
group
image
features
feature
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202310444646.8A
Other languages
English (en)
Other versions
CN116663648A (zh
Inventor
曾炜
陈建平
袁孝宇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peking University
Original Assignee
Peking University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peking University filed Critical Peking University
Priority to CN202310444646.8A priority Critical patent/CN116663648B/zh
Publication of CN116663648A publication Critical patent/CN116663648A/zh
Application granted granted Critical
Publication of CN116663648B publication Critical patent/CN116663648B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/44Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
    • G06V10/443Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components by matching or filtering
    • G06V10/449Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/50Extraction of image or video features by performing operations within image blocks; by using histograms, e.g. histogram of oriented gradients [HoG]; by summing image-intensity values; Projection analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/74Image or video pattern matching; Proximity measures in feature spaces
    • G06V10/761Proximity, similarity or dissimilarity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • G06V10/806Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Biomedical Technology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Databases & Information Systems (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Biophysics (AREA)
  • Data Mining & Analysis (AREA)
  • Biodiversity & Conservation Biology (AREA)
  • Image Analysis (AREA)
  • Manipulator (AREA)

Abstract

本申请涉及一种模型训练方法、装置、设备及存储介质,该方法在神经网络模型训练的过程中,引入第一特征空间,并获取当前训练图像在第一特征空间中的第一类特征,进一步计算第一类特征与第一特征组中各特征的相似度,得到第一相似度组,以第一相似度组校正与神经网络模型的第二特征空间相关的第二相似度组,使得校正后的校正相似度组关注到与当前训练图像属于相同语义类别但相似度较低的样本,从而缓解了不可靠样本带来的错误指导。

Description

模型训练方法、装置、设备及存储介质
技术领域
本申请涉及计算机领域,尤其涉及一种模型训练方法、装置、设备及存储介质。
背景技术
自监督表示学习是近年来的研究热点,主要是通过解决精心设计的代理任务来从大规模无标注数据中学习到通用的特征表示,并以此作为大量下游任务的初始化。
当前,基于对比学习的实例判别方法在自监督表示学习中表现出了巨大的潜力,在各个下游任务中已经接近甚至超过有监督预训练。它的具体实现是在特征空间中拉近利用数据增强技术得到的正样本同时把其他样本作为负样本推远。但是,实例判别方法的负样本中不可避免地存在与当前样本属于相同语义类别的样本,即假阴性(FalseNegative)样本,这导致学习到的语义结构受限。
发明内容
本申请提供了一种模型训练方法、装置、设备及存储介质,用以解决模型训练过程中,假阴性样本导致学习到的语义结构受限,给模型训练带来错误指导的问题。
第一方面,提供一种模型训练方法,包括:
获取当前训练图像的第一增强视图在第一特征空间的第一类特征,所述第一特征空间用于描述所述当前训练图像需要被提取的特征的维度;
计算所述第一类特征与第一特征组中各特征的相似度,得到第一相似度组;所述第一特征组中的特征为训练数据集中训练图像的第一类特征;所述训练数据集包括所述当前训练图像;
采用所述第一相似度组校正第二相似度组,得到校正相似度组;所述第二相似度组为所述第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;所述第二特征组中的特征为所述训练数据集中训练图像的第二类特征;所述第一增强视图的第二类特征为所述第一增强视图在第二特征空间的图像特征;所述第二特征空间中的维度与所述第一特征空间的维度不完全相同;
利用所述校正相似度组优化神经网络模型的参数,所述神经网络模型用于提取所述第二类特征。
可选地,利用所述校正相似度组优化所述第二特征空间对应的神经网络模型的参数,包括:
利用所述校正相似度组、第三相似度组和预设的损失函数,计算模型损失;所述第三相似度组为第二增强视图的第二类特征与所述第二特征组中各特征的相似度的集合;所述第二增强视图为所述当前训练图像的另一视图;
采用所述模型损失优化所述神经网络模型的参数。
可选地,采用所述第一相似度组校正第二相似度组,得到校正相似度组,包括:
获取所述第一相似度组中最大的K个相似度;
对于所述K个图像相似度中的任一图像相似度,获取所述任一图像相似度在所述第一相似度组中的目标位置;
采用所述任一图像相似度更新所述第二相似度组中处于所述目标位置的图像相似度,得到所述校正相似度组。
可选地,采用所述任一图像相似度更新所述第二相似度组中处于所述目标位置的图像相似度,以得到所述校正相似度组,包括:
获取所述K个图像相似度中的最大图像相似度;
计算所述任一图像相似度与所述最大图像相似度的商值;
采用所述商值替换所述第二相似度组中处于所述目标位置的图像相似度,得到所述校正相似度组。
可选地,采用所述商值替换所述第二相似度组中处于所述目标位置的图像相似度,以得到所述校正相似度组之前,还包括:
确定所述商值和处于所述目标位置的图像相似度属于归一化取值。
可选地,获取当前训练图像的第一增强视图在第一特征空间的第一类特征之前,还包括:
确定所述神经网络模型的迭代轮次未超过预设轮次。
可选地,获取当前训练图像的第一增强视图在第一特征空间的第一类特征,包括:
获取所述第一增强视图的方向梯度直方图特征;
将所述方向梯度直方图特征作为所述第一类特征。
第二方面,提供一种模型训练装置,包括:
获取模块,用于获取当前训练图像的第一增强视图在第一特征空间的第一类特征,所述第一特征空间用于描述所述当前训练图像需要被提取的特征的维度;
计算模块,用于计算所述第一类特征与第一特征组中各特征的相似度,得到第一相似度组;所述第一特征组中的特征为训练数据集中训练图像的第一类特征;所述训练数据集包括所述当前训练图像;
校正模块,用于采用所述第一相似度组校正第二相似度组,得到校正相似度组;所述第二相似度组为所述第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;所述第二特征组中的特征为所述训练数据集中训练图像的第二类特征;所述第一增强视图的第二类特征为所述第一增强视图在第二特征空间的图像特征;所述第二特征空间中的维度与所述第一特征空间的维度不完全相同;
优化模块,用于利用所述校正相似度组优化神经网络模型的参数,所述神经网络模型用于提取所述第二类特征。
第三方面,提供一种电子设备,其特征在于,包括:处理器、存储器和通信总线,其中,处理器和存储器通过通信总线完成相互间的通信;
存储器,用于存储计算机程序;
处理器,用于执行存储器中所存储的程序,实现第一方面所述的模型训练方法。
第四方面,提供一种计算机可读存储介质,存储有计算机程序,其特征在于,计算机程序被处理器执行时实现第一方面所述的模型训练方法。
本申请实施例提供的上述技术方案与现有技术相比具有如下优点:本申请实施例提供的该方法,在神经网络模型训练的过程中,引入第一特征空间,并获取当前训练图像在第一特征空间中的第一类特征,进一步计算第一类特征与第一特征组中各特征的相似度,得到第一相似度组,以第一相似度组校正与神经网络模型相关的第二相似度组,使得校正后的校正相似度组关注到与当前训练图像属于相同语义类别但相似度较低的样本,从而缓解了不可靠样本带来的错误指导。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本发明的实施例,并与说明书一起用于解释本发明的原理。
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为传统的基于样本关联的方法中目标分布存在不可靠关联的场景示意图;
图2为本实施例中模型训练方法的流程示意图;
图3为本实施例中神经网络模型的结构示意图;
图4为本申请实施例中利用改进后的CSAC在特征空间中的结果示意图;
图5为本实施例中模型训练装置的结构示意图;
图6为本实施例中电子设备的结构示意图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。
需要说明的是,本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
在传统的基于样本关联的方法中,目标分布不可避免地存在不可靠的样本关联,即不同语义类别样本的相似度大于相同语义类别样本的相似度。以图1为例,目标图片为一张马的图片,即图1中左侧包含马的图片,需要与目标图片计算相似度的图片为图1中右侧分别包含狗、马、鸟、猫的四张图片。在传统的基于样本关联的方法得到的目标分布中,往往会出现目标图片与图1中右侧包含狗、鸟或猫的图片的相似度大于目标图片与图1中右侧包含马的图片的相似度,这种情况会导致目标分布带来错误指导。
为了解决上述问题,本申请实施例提供一种模型训练方法,该方法可应用于电子设备中,该电子设备可以包括终端设备或者服务器,本申请实施例不做限定。该终端设备可以是诸如手机、平板电脑、笔记本电脑、掌上电脑、PAD(Personal Digital Assistant,个人数字助理)、等等的移动终端以及诸如数字TV、台式计算机等等的固定终端。
如图2所示,该方法可包括以下步骤:
步骤201、获取当前训练图像的第一增强视图在第一特征空间的第一类特征,第一特征空间用于描述当前训练图像需要被提取的特征的维度。
本实施例中,可以通过多种数据增强的组合方式得到第一增强视图,如通过对当前训练图像进行随机裁剪、翻动和/或抖动得到第一增强视图,本实施例对此不作具体限定。
本实施例中,第一类特征可以为方向梯度直方图(Histogram ofOrientedGradient,HOG)特征或Gabor特征等,本实施例对此不做具体限定。
其中,HOG特征为通过计算和统计图像局部区域的梯度方向直方图来构成的特征。通常在一副图像中,局部目标的表象和形状(appearance and shape)能够被梯度或边缘的方向密度分布很好地描述。具体获取HOG特征时,首先将图像分成小的连通区域,得到多个细胞单元,然后采集细胞单元中各像素点的梯度的或边缘的方向直方图,最后把这些直方图组合起来就可得到图像的HOG特征。
本实施例中,在采用HOG特征作为第一类特征时,由于HOG特征判别能力有限,因此为了保证后续的校正效果同时还能降低计算量,只在神经网络模型的训练早期应用HOG特征进行校正。
具体实现时,一个可选实施例中,获取当前训练图像的第一增强视图在第一特征空间的第一类特征之前,获取神经网络模型的迭代轮次(epoch),并在确定该迭代轮次未超过预设轮次时,才获取第一增强视图的第一特征空间的第一类特征。
其中,一个epoch表示训练数据集中的训练图像都送入过神经网络模型,并且每张训练图像都完成了一次前向计算和反向传播的过程。
其中,预设轮次可以人为基于经验或实际需要设置,本实施例对此不作具体限定。
本实施例中,当迭代轮次超过预设轮次时,不再获取第一增强视图的第一类特征,并且按照现有技术中对神经网络模型的训练过程执行,这里不再展开描述。
步骤202、计算第一类特征与第一特征组中各特征的相似度,得到第一相似度组;第一特征组中的特征为训练数据集中训练图像的第一类特征;训练数据集包括当前训练图像。
本实施例中,第一特征组中可以包括多个特征,每个特征均属于训练数据集中训练图像的第一类特征所构成的特征集合。第一特征组中可以包括训练数据集中每张训练图像的第一类特征,当然为了降低计算量,也可以仅包括训练数据集中部分训练图像的第一类特征。本实施例对此不做具体限定。
其中,当第一特征组中仅包括训练数据集中部分训练图像的第一类特征时,为了尽量提高计算准确度,设置第一特征组中第一类特征随着当前训练图像的改变而动态变化。具体说来,第一特征组以队列的形式实现,队列遵循先进先出的原则,在获取当前训练图像的第一类特征后,将当前训练图像的第一类特征插入第一特征组对应的队列,相应地,该队列在当前训练图像的第一类特征插入之前的队尾元素则出队。比如,在当前训练图像的第一类特征插入之前,队列中的元素按由队首到队尾分别为a、b、c,那么将当前训练图像的第一类特征d插入队列后,队列中的元素由队首到队尾分别为d、a、b。
实际应用中,第一特征组以队列形式实现时,该队列中各元素的初始值为随机数值,并不为训练数据集中训练图像的第一类特征。随着训练的进行,队列中各元素逐步被训练数据集中训练图像的第一特征代替并动态更新。
一个可选实施例中,通过计算余弦相似度得到第一相似度组。具体说来,计算第一类特征与第一特征组中各特征的余弦相似度,并将余弦相似度构成的集合作为第一相似度组。应当理解的是,第一相似度组中的特定位置的相似度,表示当前训练图像的第一类特征与第一特征组中特定位置的特征之间的相似度。即第一相似度组与第一特征组通过位置具有一一对应的关系。
步骤203、采用第一相似度组校正第二相似度组,得到校正相似度组;第二相似度组为第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;第二特征组中的特征为训练数据集中训练图像的第二类特征;第一增强视图的第二类特征为第一增强视图在第二特征空间的图像特征;第二特征空间中的维度与第一特征空间的维度不完全相同。
应理解,第二特征空间的维度与第一特征空间的维度不完全相同,表示第二特征空间的维度与第一特征空间的的维度完全不同或,第二特征空间的维度与第一特征空间的维度存在部分相同。
应理解,这里也可以采用余弦相似度计算得到第二特征组,具体实现过程详见前文采用余弦相似度计算第一特征组的介绍,这里不再赘述。
本实施例中,可以采用加权平均、最大值替换等方式实现第一相似度组对第二相似度组的校正,本实施例对此不做具体限定。
一个可选实施例中,采用加权平均实现第一相似度组对第二相似度组的校正时,对第一相似度组和第二相似度组对应位置的相似度求和,然后对求和结果求平均值,并将平均值作为校正相似度组在对应位置的相似度。
一个可选实施例中,采用最大值替换实现第一相似度组对第二相似度组的校正时,获取第一相似度组中最大的K个相似度;对于K个图像相似度中的任一图像相似度,获取任一图像相似度在第一相似度组中的目标位置;采用任一图像相似度更新第二相似度组中处于目标位置的图像相似度,得到校正相似度组。
本实施例中,第一相似度组中最大的K个相似度指的是,在按照从大到小的方式对第一相似度组中的相似度进行排序的情况下,第一相似度组中的前K个相似度。比如,第一相似度组中具有0.3、0.5、0.4、0.7四个相似度,那么第一相似度组中最大的两个相似度则为0.5和0.7。
本实施例中,在采用任一图像相似度更新第二相似度组中处于目标位置的图像相似度时,可以以任一图像相似度直接替换第二相似度组中处于目标位置的图像相似度,得到校正相似度组。
本实施例中,为了使得校正之后的校正相似度组可以着重关注到与当前训练图像属于相同语义类别但相似度较低的样本,在采用任一图像相似度更新第二相似度组中处于目标位置的图像相似度时,获取K个图像相似度中的最大图像相似度;计算任一图像相似度与最大图像相似度的商值;采用商值替换第二相似度组中处于目标位置的图像相似度,得到校正相似度组。
举例说来,K个图像相似度分别为0.4、0.3、0.5,那么对第二相似度组进行的替换的K个图像相似度实际为0.4/0.5=0.8、0.3/0.5=0.6、0.5/0.5=1。
本实施例中,由于商值和处于目标位置的图像相似度有可能表示意义相反的量,比如商值为正数,而处于目标位置的图像相似度为负数,在这种情况下如果直接进行替换,则可能影响校正的效果,因此在采用商值替换第二相似度组中处于目标位置的图像相似度,得到校正相似度组之前,确定商值和处于目标位置的图像相似度属于归一化取值。
本实施例中,使商值和处于目标位置的图像相似度属于归一化取值可以有如下两种实现:
直接对商值和处于目标位置的图像相似度进行归一化操作;或,
分别对第一相似度组和第二相似度组中的相似度进行归一化操作。
本实施例中,第二特征组可以包括多个特征,每个特征均属于训练数据集中训练图像的第二类特征所构成的特征集合。第二特征组中可以包括训练数据集中每张训练图像的第二类特征,当然为了降低计算量,也可以仅包括训练数据集中部分训练图像的第二类特征。本实施例对此不做具体限定。
其中,当第二特征组中仅包括训练数据集中部分训练图像的第二类特征时,为了尽量提高计算准确度,设置第二特征组中第二类特征随着当前训练图像的改变而动态变化。具体说来,第二特征组以队列的形式实现,队列遵循先进先出的原则,在获取当前训练图像的第二类特征后,将当前训练图像的第二类特征插入第二特征组对应的队列,相应地,该队列在当前训练图像的第二类特征插入之前的队尾元素则出队。实际应用中,第二特征组以队列形式实现时,该队列中各元素的初始值为随机数值,并不为训练数据集中训练图像的第二类特征。随着训练的进行,队列中各元素逐步被训练数据集中训练图像的第二特征代替并动态更新。
步骤204、利用校正相似度组优化神经网络模型的参数,神经网络模型用于提取第二类特征。
本实施例中的神经网络模型包括但不限于自监督表示学习模型,比如基于跨空间样本关联校正(CSAC,Cross-space Sample AssociationCorrection)的自监督表示学习模型。
本实施例中,在利用校正相似度组进行神经网络模型参数的优化时,借助预设的损失函数计算模型损失,并通过反向传播该模型损失实现对神经网络模型的参数的优化。
具体说来,一个可选实施例中,利用校正相似度组、第三相似度组和预设的损失函数,计算模型损失;第三相似度组为第二增强视图的第二类特征与第二特征组中各特征的相似度的集合;第二增强视图为当前训练图像的另一视图;采用模型损失优化神经网络模型的参数。
本实施例中,预设的损失函数包括但不限于交叉熵损失函数。
本实施例提供的技术方案中,在神经网络模型训练的过程中,引入第一特征空间,并获取当前训练图像在第一特征空间中的第一类特征,进一步计算第一类特征与第一特征组中各特征的相似度,得到第一相似度组,以第一相似度组校正与神经网络模型的第二特征空间相关的第二相似度组,使得校正后的校正相似度组关注到与当前训练图像属于相同语义类别但相似度较低的样本,从而缓解了不可靠样本带来的错误指导。
以下以神经网络模型为基于CSAC的自监督表示学习模型,并引入计算HOG特征的分支为例,介绍对模型的训练过程。请参照图3,图3为模型的结构图。
整个模型由四部分组成:特征提取模块,相似度分布计算模块,分布校正模块,分布对齐模块。
特征提取模块包括深度特征提取和HOG特征提取。深度特征分别在编码器(如ResNet)和其对应的动量更新编码器后接MLP层来得到两个不同数据增强视图的特征。HOG特征提取直接对图像分块计算局部梯度,然后按顺序拼接起来作为特征向量。
相似度分布计算模块分别计算在线分布,目标分布和辅助分布。在线分布和目标分布分别由编码器和动量更新编码器得到的特征与同一个队列计算余弦相似度,然后进行Softmax操作。辅助分布是由HOG特征与其对应的队列计算余弦相似度得到。
分布校正模块是利用辅助分布进行排序得到前k个相似度最大的位置,将得到的k个相似度进行Softmax操作然后除以其最大值。然后将得到的数值按照在辅助分布中的位置对应赋值到目标分布中的相应位置。以此作为校正后的目标分布。由于HOG特征判别能力有限,因此只在训练早期应用HOG特征进行校正。
分布对齐模块是利用交叉熵损失函数将在线分布向目标分布对齐,其中在线分布所在分支进行梯度回传,目标分布所在分支无梯度回传。为防止模型坍塌,目标分布用更小的温度系数来锐化样本的相对关系。
模型的训练流程为:
当前训练图像经过多种数据增强的组合(如随机裁剪、翻转、抖动)得到两个不同视图。
将数据增强后的视图经过深度神经网络和HOG计算分支进行特征提取。
提取后的视觉特征分别与其对应的队列计算图2中的在线分布,目标分布和辅助分布。
在前m个epoch中利用辅助分布中相似度最大的k个位置校正目标分布中的对应位置。
利用交叉熵损失优化在线分布向校正后的目标分布对齐。
当前批处理特征进队列,队列中的旧样本相应出队列,以此得到更新后的队列。
在m个epoch之后,得到的目标分布不再进行校正,直接进行分布对齐。
训练n个epoch之后,得到的模型作为预训练模型进行评估。
为了评估改进后的CSAC的效果,使改进后的CSAC在Cifar-10和ImageNet-100(简写为IN-100)两个数据集上进行训练和测试。其中Cifar-10数据集包括50000张训练图像,10000张测试图像,共10个类别,图像大小为32*32,编码器为ResNet-18结构。IN-100数据集包括126689张训练图像和5000张测试图像,共100个类别,图像大小为224*224,编码器为ResNet-50结构。在Cifar-10训练200个epoch,在IN-100训练100个epoch。线性分类器性能作为下游任务评价指标。
表格1展示了改进后的CSAC在Cifar-10和ImageNet-100两个数据集上与MoCo-v2和分布对齐基准在top-1和top-5线性分类准确率下的对比。实验结果表明改进后的CSAC显著提高了分布对齐基准的性能,其中在Cifar-10数据集上可以提升2.6%的top-1准确率,在ImageNet-100上可以提升1.6%的top-1准确率。该实验结果证明了改进后的CSAC可以缓解目标分布中不可靠关联的错误指导问题,因此可以得到更高的准确率。
表格2展示了改进后的CSAC在Cifar-10数据集上进行分布校正时所需的参数对比:k近邻中的k值和校正epoch数。在Cifar-10数据集上利用辅助分布的10近邻在训练的前80个epoch校正目标分布可以得到最好的性能88.0%。
表格1
表格2
图4中利用t-SNE技术将高维特征空间嵌入到2维空间,相同颜色的点属于相同的语义类别。经过对比,改进后的CSAC在特征空间中更具有判别性。
基于同一构思,本申请实施例中提供了一种模型训练装置,该装置的具体实施可参见方法实施例部分的描述,重复之处不再赘述,如图5所示,该装置主要包括:
获取模块501,用于获取当前训练图像的第一增强视图在第一特征空间的第一类特征,第一特征空间用于描述当前训练图像需要被提取的特征的维度;
计算模块502,用于计算第一类特征与第一特征组中各特征的相似度,得到第一相似度组;第一特征组中的特征为训练数据集中训练图像的第一类特征;训练数据集包括当前训练图像;
校正模块503,用于采用第一相似度组校正第二相似度组,得到校正相似度组;第二相似度组为第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;第二特征组中的特征为训练数据集中训练图像的第二类特征;第一增强视图的第二类特征为第一增强视图在第二特征空间的图像特征;第二特征空间中的维度与第一特征空间的维度不完全相同;
优化模块504,用于利用校正相似度组优化神经网络模型的参数,神经网络模型用于提取第二类特征。
可选地,优化模块504用于:
利用校正相似度组、第三相似度组和预设的损失函数,计算模型损失;第三相似度组为第二增强视图的第二类特征与第二特征组中各特征的相似度的集合;第二增强视图为当前训练图像的另一视图;
采用模型损失优化神经网络模型的参数。
可选地,优化模块504用于:
获取第一相似度组中最大的K个相似度;
对于K个图像相似度中的任一图像相似度,获取任一图像相似度在第一相似度组中的目标位置;
采用任一图像相似度更新第二相似度组中处于目标位置的图像相似度,得到校正相似度组。
可选地,优化模块504用于:
获取K个图像相似度中的最大图像相似度;
计算任一图像相似度与最大图像相似度的商值;
采用商值替换第二相似度组中处于目标位置的图像相似度,得到校正相似度组。
可选地,该装置还用于:
采用商值替换第二相似度组中处于目标位置的图像相似度,以得到校正相似度组之前,还确定商值和处于目标位置的图像相似度属于归一化取值。
可选地,该装置还用于:
获取当前训练图像的第一增强视图在第一特征空间的第一类特征之前,还确定神经网络模型的迭代轮次未超过预设轮次。
可选地,获取模块501用于:
获取第一增强视图的方向梯度直方图特征;
将方向梯度直方图特征作为第一类特征。
基于同一构思,本申请实施例中还提供了一种电子设备,如图6所示,该电子设备主要包括:处理器601、存储器602和通信总线603,其中,处理器601和存储器602通过通信总线603完成相互间的通信。其中,存储器602中存储有可被处理器601执行的程序,处理器601执行存储器602中存储的程序,实现如下步骤:
获取当前训练图像的第一增强视图在第一特征空间的第一类特征,第一特征空间用于描述当前训练图像需要被提取的特征的维度;
计算第一类特征与第一特征组中各特征的相似度,得到第一相似度组;第一特征组中的特征为训练数据集中训练图像的第一类特征;训练数据集包括当前训练图像;
采用第一相似度组校正第二相似度组,得到校正相似度组;第二相似度组为第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;第二特征组中的特征为训练数据集中训练图像的第二类特征;第一增强视图的第二类特征为第一增强视图在第二特征空间的图像特征;第二特征空间中的维度与第一特征空间的维度不完全相同;
利用校正相似度组优化神经网络模型的参数,神经网络模型用于提取第二类特征。
上述电子设备中提到的通信总线603可以是外设部件互连标准(PeripheralComponent Interconnect,简称PCI)总线或扩展工业标准结构(Extended IndustryStandard Architecture,简称EISA)总线等。该通信总线603可以分为地址总线、数据总线、控制总线等。为便于表示,图6中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
存储器602可以包括随机存取存储器(Random Access Memory,简称RAM),也可以包括非易失性存储器(non-volatile memory),例如至少一个磁盘存储器。可选地,存储器还可以是至少一个位于远离前述处理器601的存储装置。
上述的处理器601可以是通用处理器,包括中央处理器(CentralProcessingUnit,简称CPU)、网络处理器(Network Processor,简称NP)等,还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application SpecificIntegrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本申请的又一实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质中存储有计算机程序,当该计算机程序在计算机上运行时,使得计算机执行上述实施例中所描述的模型训练方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。该计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行该计算机指令时,全部或部分地产生按照本申请实施例的流程或功能。该计算机可以时通用计算机、专用计算机、计算机网络或者其他可编程装置。该计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,计算机指令从一个网站站点、计算机、服务器或者数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、微波等)方式向另外一个网站站点、计算机、服务器或数据中心进行传输。该计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。该可用介质可以是磁性介质(例如软盘、硬盘、磁带等)、光介质(例如DVD)或者半导体介质(例如固态硬盘)等。
需要说明的是,在本文中,诸如“第一”和“第二”等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上仅是本发明的具体实施方式,使本领域技术人员能够理解或实现本发明。对这些实施例的多种修改对本领域的技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所申请的原理和新颖特点相一致的最宽的范围。

Claims (8)

1.一种模型训练方法,其特征在于,包括:
获取当前训练图像的第一增强视图在第一特征空间的第一类特征,所述第一特征空间用于描述所述当前训练图像需要被提取的特征的维度;
计算所述第一类特征与第一特征组中各特征的相似度,得到第一相似度组;所述第一特征组中的特征为训练数据集中训练图像的第一类特征;所述训练数据集包括所述当前训练图像;
采用所述第一相似度组校正第二相似度组,得到校正相似度组;所述第二相似度组为所述第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;所述第二特征组中的特征为所述训练数据集中训练图像的第二类特征;所述第一增强视图的第二类特征为所述第一增强视图在第二特征空间的图像特征;所述第二特征空间中的维度与所述第一特征空间的维度不完全相同;
利用所述校正相似度组优化神经网络模型的参数,所述神经网络模型用于提取所述第二类特征;
采用所述第一相似度组校正第二相似度组,得到校正相似度组,包括:
获取所述第一相似度组中最大的K个图像相似度;
对于所述K个图像相似度中的任一图像相似度,获取所述任一图像相似度在所述第一相似度组中的目标位置;
采用所述任一图像相似度更新所述第二相似度组中处于所述目标位置的图像相似度,得到所述校正相似度组;
采用所述任一图像相似度更新所述第二相似度组中处于所述目标位置的图像相似度,以得到所述校正相似度组,包括:
获取所述K个图像相似度中的最大图像相似度;
计算所述任一图像相似度与所述最大图像相似度的商值;
采用所述商值替换所述第二相似度组中处于所述目标位置的图像相似度,得到所述校正相似度组。
2.根据权利要求1所述的方法,其特征在于,利用所述校正相似度组优化所述第二特征空间对应的神经网络模型的参数,包括:
利用所述校正相似度组、第三相似度组和预设的损失函数,计算模型损失;所述第三相似度组为第二增强视图的第二类特征与所述第二特征组中各特征的相似度的集合;所述第二增强视图为所述当前训练图像的另一视图;
采用所述模型损失优化所述神经网络模型的参数。
3.根据权利要求1所述的方法,其特征在于,采用所述商值替换所述第二相似度组中处于所述目标位置的图像相似度,以得到所述校正相似度组之前,还包括:
确定所述商值和处于所述目标位置的图像相似度属于归一化取值。
4.根据权利要求1所述的方法,其特征在于,获取当前训练图像的第一增强视图在第一特征空间的第一类特征之前,还包括:
确定所述神经网络模型的迭代轮次未超过预设轮次。
5.根据权利要求1所述的方法,其特征在于,获取当前训练图像的第一增强视图在第一特征空间的第一类特征,包括:
获取所述第一增强视图的方向梯度直方图特征;
将所述方向梯度直方图特征作为所述第一类特征。
6.一种模型训练装置,其特征在于,包括:
获取模块,用于获取当前训练图像的第一增强视图在第一特征空间的第一类特征,所述第一特征空间用于描述所述当前训练图像需要被提取的特征的维度;
计算模块,用于计算所述第一类特征与第一特征组中各特征的相似度,得到第一相似度组;所述第一特征组中的特征为训练数据集中训练图像的第一类特征;所述训练数据集包括所述当前训练图像;
校正模块,用于采用所述第一相似度组校正第二相似度组,得到校正相似度组;所述第二相似度组为所述第一增强视图的第二类特征与第二特征组中各特征的相似度的集合;所述第二特征组中的特征为所述训练数据集中训练图像的第二类特征;所述第一增强视图的第二类特征为所述第一增强视图在第二特征空间的图像特征;所述第二特征空间中的维度与所述第一特征空间的维度不完全相同;
优化模块,用于利用所述校正相似度组优化神经网络模型的参数,所述神经网络模型用于提取所述第二类特征;
校正模块用于:
获取所述第一相似度组中最大的K个图像相似度;
对于所述K个图像相似度中的任一图像相似度,获取所述任一图像相似度在所述第一相似度组中的目标位置;
采用所述任一图像相似度更新所述第二相似度组中处于所述目标位置的图像相似度,得到所述校正相似度组;
校正模块用于:
获取所述K个图像相似度中的最大图像相似度;
计算所述任一图像相似度与所述最大图像相似度的商值;
采用所述商值替换所述第二相似度组中处于所述目标位置的图像相似度,得到所述校正相似度组。
7.一种电子设备,其特征在于,包括:处理器、存储器和通信总线,其中,处理器和存储器通过通信总线完成相互间的通信;
存储器,用于存储计算机程序;
处理器,用于执行存储器中所存储的程序,实现权利要求1-5任一项所述的模型训练方法。
8.一种计算机可读存储介质,存储有计算机程序,其特征在于,计算机程序被处理器执行时实现权利要求1-5任一项所述的模型训练方法。
CN202310444646.8A 2023-04-23 2023-04-23 模型训练方法、装置、设备及存储介质 Active CN116663648B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310444646.8A CN116663648B (zh) 2023-04-23 2023-04-23 模型训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310444646.8A CN116663648B (zh) 2023-04-23 2023-04-23 模型训练方法、装置、设备及存储介质

Publications (2)

Publication Number Publication Date
CN116663648A CN116663648A (zh) 2023-08-29
CN116663648B true CN116663648B (zh) 2024-04-02

Family

ID=87721420

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310444646.8A Active CN116663648B (zh) 2023-04-23 2023-04-23 模型训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN116663648B (zh)

Citations (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111898682A (zh) * 2020-07-31 2020-11-06 平安科技(深圳)有限公司 基于多个源模型修正新模型的方法、装置以及计算机设备
CN112507106A (zh) * 2021-02-05 2021-03-16 恒生电子股份有限公司 深度学习模型的训练方法、装置和faq相似度判别方法
CN112508130A (zh) * 2020-12-25 2021-03-16 商汤集团有限公司 聚类方法及装置、电子设备和存储介质
WO2021164306A1 (zh) * 2020-09-17 2021-08-26 平安科技(深圳)有限公司 图像分类模型的训练方法、装置、计算机设备及存储介质
CN113378940A (zh) * 2021-06-15 2021-09-10 北京市商汤科技开发有限公司 神经网络训练方法、装置、计算机设备及存储介质
CN113536763A (zh) * 2021-07-20 2021-10-22 北京中科闻歌科技股份有限公司 一种信息处理方法、装置、设备及存储介质
CN113592911A (zh) * 2021-07-31 2021-11-02 西南电子技术研究所(中国电子科技集团公司第十研究所) 表观增强深度目标跟踪方法
CN113743535A (zh) * 2019-05-21 2021-12-03 北京市商汤科技开发有限公司 神经网络训练方法及装置以及图像处理方法及装置

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108694200B (zh) * 2017-04-10 2019-12-20 北京大学深圳研究生院 一种基于深度语义空间的跨媒体检索方法

Patent Citations (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113743535A (zh) * 2019-05-21 2021-12-03 北京市商汤科技开发有限公司 神经网络训练方法及装置以及图像处理方法及装置
CN111898682A (zh) * 2020-07-31 2020-11-06 平安科技(深圳)有限公司 基于多个源模型修正新模型的方法、装置以及计算机设备
WO2021164306A1 (zh) * 2020-09-17 2021-08-26 平安科技(深圳)有限公司 图像分类模型的训练方法、装置、计算机设备及存储介质
CN112508130A (zh) * 2020-12-25 2021-03-16 商汤集团有限公司 聚类方法及装置、电子设备和存储介质
CN112507106A (zh) * 2021-02-05 2021-03-16 恒生电子股份有限公司 深度学习模型的训练方法、装置和faq相似度判别方法
CN113378940A (zh) * 2021-06-15 2021-09-10 北京市商汤科技开发有限公司 神经网络训练方法、装置、计算机设备及存储介质
CN113536763A (zh) * 2021-07-20 2021-10-22 北京中科闻歌科技股份有限公司 一种信息处理方法、装置、设备及存储介质
CN113592911A (zh) * 2021-07-31 2021-11-02 西南电子技术研究所(中国电子科技集团公司第十研究所) 表观增强深度目标跟踪方法

Also Published As

Publication number Publication date
CN116663648A (zh) 2023-08-29

Similar Documents

Publication Publication Date Title
WO2019109743A1 (zh) Url攻击检测方法、装置以及电子设备
CN108073902B (zh) 基于深度学习的视频总结方法、装置及终端设备
WO2022033150A1 (zh) 图像识别方法、装置、电子设备及存储介质
CN109002766B (zh) 一种表情识别方法及装置
CN110532417B (zh) 基于深度哈希的图像检索方法、装置及终端设备
CN108681746B (zh) 一种图像识别方法、装置、电子设备和计算机可读介质
CN108985190B (zh) 目标识别方法和装置、电子设备、存储介质
CN107209860A (zh) 使用分块特征来优化多类图像分类
CN109165309B (zh) 负例训练样本采集方法、装置及模型训练方法、装置
CN112990318B (zh) 持续学习方法、装置、终端及存储介质
CN112115317A (zh) 一种针对深度哈希检索的有目标攻击方法及终端设备
CN110705489B (zh) 目标识别网络的训练方法、装置、计算机设备和存储介质
CN111223128A (zh) 目标跟踪方法、装置、设备及存储介质
CN111694954B (zh) 图像分类方法、装置和电子设备
CN112749737A (zh) 图像分类方法及装置、电子设备、存储介质
Li et al. Individual dairy cow identification based on lightweight convolutional neural network
WO2019085332A1 (zh) 金融数据分析方法、应用服务器及计算机可读存储介质
CN110910325B (zh) 一种基于人工蝴蝶优化算法的医疗影像处理方法及装置
CN112668718B (zh) 神经网络训练方法、装置、电子设备以及存储介质
CN113011532A (zh) 分类模型训练方法、装置、计算设备及存储介质
Fu et al. Lightweight individual cow identification based on Ghost combined with attention mechanism
CN116663648B (zh) 模型训练方法、装置、设备及存储介质
CN109657710B (zh) 数据筛选方法、装置、服务器及存储介质
CN116503670A (zh) 图像分类及模型训练方法、装置和设备、存储介质
CN114155388B (zh) 一种图像识别方法、装置、计算机设备和存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant