CN111368989A - 神经网络模型的训练方法、装置、设备及可读存储介质 - Google Patents
神经网络模型的训练方法、装置、设备及可读存储介质 Download PDFInfo
- Publication number
- CN111368989A CN111368989A CN201811592198.1A CN201811592198A CN111368989A CN 111368989 A CN111368989 A CN 111368989A CN 201811592198 A CN201811592198 A CN 201811592198A CN 111368989 A CN111368989 A CN 111368989A
- Authority
- CN
- China
- Prior art keywords
- trained
- small
- images
- sample set
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本发明公开一种神经网络模型的训练方法、装置、设备及可读存储介质。该方法包括:获取图像训练数据集;根据图像训练数据集,创建小批量样本集;获取小批量样本集的特征集合;依次以小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在特征集合中确定其对应的难正例的特征与难负例的特征;根据每一个锚点的特征及其对应的难正例的特征与难负例的特征,确定小批量样本集的三元组损失;根据小批量样本集的三元组损失,训练神经网络模型;以及当小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。该方法能够提高网络收敛速度、泛化能力及算法性能。
Description
技术领域
本发明涉及深度学习技术领域,具体而言,涉及一种神经网络模型的训练方法、装置、设备及可读存储介质。
背景技术
随着以卷积神经网络为代表的深度学习技术广泛应用于各个领域,深度学习技术在透射图像领域也越来越成为研究热点。但是由于透射图像成像方式及待拍摄货物的不同视角、堆叠方式、重叠等因素,此外还有很多同一货物图像的形态、规格千差万别,导致在对透射图像进行训练时神经网络难以收敛或者神经网络收敛的很好但是泛化能力较弱等问题。
目前,在面对目标数量大及类别细粒度难以区分的问题时,度量学习方法是常用的神经网络训练方法。一个神经网络性能的训练好坏,损失函数占着举足轻重的作用,在度量学习方法中Triplet loss(三元组损失函数)是常用的损失函数,其利用“使同一类别在空间中靠近,不同类别在空间中远离”的思想来训练神经网络,进而来实现不同类别目标的区分。
然而,传统的Triplet(三元组)网络,在面对大量数据集训练模型时,会存在部分三元组是由容易的正例和负例组成的,对Triplet网络的训练过程价值较小,从而影响网络收敛速度和泛化能力的问题。
此外,传统的三元组网络,在面对同类样本差异大时,还存在由于某类的聚类空间过大而导致正负样本难以分开,影响算法性能的问题。
在所述背景技术部分公开的上述信息仅用于加强对本发明的背景的理解,因此它可以包括不构成对本领域普通技术人员已知的现有技术的信息。
发明内容
本发明提供一种神经网络模型的训练方法、装置、设备及可读存储介质,能够提高网络收敛速度、泛化能力及算法性能。
本发明的其他特性和优点将通过下面的详细描述变得显然,或部分地通过本发明的实践而习得。
根据本发明的一方面,提供一种神经网络模型的训练方法,包括:获取图像训练数据集;其中,图像训练数据集包括N类待训练图像,N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数;根据图像训练数据集,创建小批量样本集;其中,小批量样本集包括从N类待训练图像中选取的n类待训练图像,n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M;获取小批量样本集的特征集合{f1,f2,f3,……,fs};其中,s=n*m;依次以小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征;根据每一个锚点的特征及其对应的难正例的特征与难负例的特征,确定小批量样本集的三元组损失;根据小批量样本集的三元组损失,训练神经网络模型;以及当小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
根据本发明的一实施方式,确定小批量样本集的三元组损失包括:根据下述公式计算小批量样本集的三元组损失:
其中,Triloss为小批量样本集的三元组损失,为锚点的特征,为难正例的特征,为难负例的特征,α为约束难正例和难负例之间距离的阈值参数,β为用于约束同类待训练图像的类内距离的约束参数,τ为用于约束同类待训练图像的特征的聚类空间的约束参数。
根据本发明的一实施方式,为每个锚点在特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征包括:在小批量样本集中,确定与锚点属于同类的各同类待训练图像;确定锚点的特征与各同类待训练图像的特征之间的类内距离构成的类内距离集合;确定类内距离集合中最大的类内距离值对应的同类待训练图像的特征为锚点对应的难正例的特征;在小批量样本集中,确定与锚点属于不同类的各不同类待训练图像;确定锚点的特征与各不同类待训练图像的特征之间的类间距离构成的类间距离集合;以及确定类间距离集合中最小的类内距离值对应的不同类待训练图像的特征为锚点对应的难负例的特征。
根据本发明的一实施方式,上述方法还包括:对各待训练图像分别进行图像预处理操作。
根据本发明的一实施方式,图像预处理操作包括:以固定尺寸,对待训练图像进行图像缩放处理;对缩放后的待训练图像进行滤波处理;以及转换滤波后的待训练图像的数据类型。
根据本发明的一实施方式,获得小批量样本集的特征集合{f1,f2,f3,……,fs}包括:基于深度神经网络对小批量样本集中的待训练图像进行特征提取,获得小批量样本集的第一特征集合{F1,F2,F3,……,Fs};以及根据小批量样本集的第一特征集合{F1,F2,F3,……,Fs},确定小批量样本集的特征集合{f1,f2,f3,……,fs};其中,特征集合{f1,f2,f3,……,fs}中的各特征为第一特征集合{F1,F2,F3,……,Fs}中各第一特征的归一化特征。
根据本发明的一实施方式,根据下述公式计算第一特征的归一化特征:
其中,Fj为第一特征,Fj(x)为第一特征中的任一像素值,F'(x)为归一化特征。
根据本发明的另一方面,提供一种神经网络模型的训练装置,包括:训练集获取模块,用于获取图像训练数据集;其中,图像训练数据集包括N类待训练图像,N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数;样本集创建模块,用于根据图像训练数据集,创建小批量样本集;其中,小批量样本集包括从N类待训练图像中选取的n类待训练图像,n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M;特征集获取模块,用于获取小批量样本集的特征集合{f1,f2,f3,……,fs};其中,s=n*m;三元组确定模块,用于依次以小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征;损失确定模块,用于根据每一个锚点的特征及其对应的难正例的特征与难负例的特征,确定小批量样本集的三元组损失;模型训练模块,用于根据小批量样本集的三元组损失,训练神经网络模型;以及模型确定模块,用于当小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
根据本发明的再一方面,提供一种计算机设备,包括:存储器、处理器及存储在存储器中并可在处理器中运行的可执行指令,处理器执行可执行指令时实现如上述任一种神经网络模型的训练方法。
根据本发明的再一方面,提供一种计算机可读存储介质,其上存储有计算机可执行指令,可执行指令被处理器执行时实现如上述任一种神经网络模型的训练方法。
根据本发明实施方式提供的神经网络模型的训练方法,利用难例挖掘的采样策略训练样本数据,从而能够加快网络收敛速度、提高网络泛化能力。
此外,根据本发明实施方式的另一些实施例,利用改进的损耗计算方法来反馈优化神经网络,可以提高神经网络训练性能,使同类样本特征内聚性更高,不同类样本特征耦合性更低。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性的,并不能限制本发明。
附图说明
通过参照附图详细描述其示例实施例,本发明的上述和其它目标、特征及优点将变得更加显而易见。
图1是根据一示例性实施方式示出的一种神经网络模型的训练方法的流程图。
图2是根据一示例性实施方式示出的另一种神经网络模型的训练方法的流程图。
图3是根据一示例性实施方式示出的再一种神经网络模型的训练方法的流程图。
图4是根据一示例性实施方式示出的一种神经网络模型的训练装置的框图。
图5是根据一示例性实施方式示出的一种电子设备的结构示意图。
图6是根据一示例性实施方式示出的一种计算机可读存储介质的示意图。
具体实施方式
现在将参考附图更全面地描述示例实施方式。然而,示例实施方式能够以多种形式实施,且不应被理解为限于在此阐述的范例;相反,提供这些实施方式使得本发明将更加全面和完整,并将示例实施方式的构思全面地传达给本领域的技术人员。附图仅为本发明的示意性图解,并非一定是按比例绘制。图中相同的附图标记表示相同或类似的部分,因而将省略对它们的重复描述。
此外,所描述的特征、结构或特性可以以任何合适的方式结合在一个或更多实施方式中。在下面的描述中,提供许多具体细节从而给出对本发明的实施方式的充分理解。然而,本领域技术人员将意识到,可以实践本发明的技术方案而省略所述特定细节中的一个或更多,或者可以采用其它的方法、组元、装置、步骤等。在其它情况下,不详细示出或描述公知结构、方法、装置、实现或者操作以避免喧宾夺主而使得本发明的各方面变得模糊。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本发明的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。
图1是根据一示例性实施方式示出的一种神经网络模型的训练方法的流程图。
参考图1,神经网络模型的训练方法10包括:
在步骤S102中,获取图像训练数据集。
其中,图像训练数据集包括N类待训练图像,N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数。
待训练图像例如可以为透视图像。
在步骤S104中,根据图像训练数据集,创建小批量(mini-batch)样本集。
其中,小批量样本集包括从N类待训练图像中选取的n类待训练图像,n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M。
例如,可以随机地从N类待训练图像中选取n类待训练图像作为小批量样本集。
在步骤S106中,获取小批量样本集的特征集合{f1,f2,f3,……,fs}。
其中,下标s=n*m,也即特征集合{f1,f2,f3,……,fs}中的每个特征分别对应于小批量样本集中的n*m个待训练图像。
在步骤S108中,依次以小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征。
在三元组网络训练过程中,基于三元组数据计算其损失函数。在一个三元组中包括:锚点(Anchor)、正样本(Positive)、负样本(Negative),其中正样本是与锚点同一类目标,负样本是与锚点不同类的目标。
为了克服三元组网络训练过程中部分容易识别的正负样本对网络性能贡献低导致网络收敛速度慢及泛化能力弱的问题,引入难例挖掘(hard sample mining)的步骤,分别以小批量样本集中的每一个待训练图像作为锚点,确定其对应的难正例(即正样本)与难负例(即负样本),并在特征集合{f1,f2,f3,……,fs}中确定其各难正例的特征与各难负例的特征。需要说明的是,本发明实施方式选择小批量样本集(较难的样本)作为训练样本,一方面是因为过难的样本集(如大批量样本集)的三元组会导致训练不稳定收敛变难,另一方面受计算机内存限制,过大的批量会使得网络无法训练。
在步骤S110中,根据每一个锚点的特征及其对应的难正例的特征与难负例的特征,确定小批量样本集的三元组损失。
在一些实施例中,可以采用传统的三元组损耗计算方法,根据下述公式(1)计算小批量样本集的三元组损失:
其中,Triloss为小批量样本集的三元组损失,为锚点的特征,为难正例的特征,为难负例的特征,α为约束难正例和难负例之间距离的阈值参数,()+表示当括号内的值为正数时取该值,为负数或0时取0,||…||2为二范数运算。
由于传统的三元组损耗计算方法使类内距离小于类间距离,但是不能保证同类样本的类内聚类空间小,为了解决由于同类样本特征差异大而导致网络难以收敛的问题,本发明还提供了另一些实施例,在三元组的损耗计算中加入约束类内距离部分的损失,例如可以根据下述公式(2)计算小批量样本集的三元组损失:
其中,Triloss为小批量样本集的三元组损失,为锚点的特征,为难正例的特征,为难负例的特征,α为约束难正例和难负例之间距离的阈值参数,β为用于约束同类待训练图像的类内距离的约束参数,例如可以为用于权衡传统的损耗与类内约束损耗所占的比重的参数,τ为用于约束同类待训练图像的特征的聚类空间的约束参数,()+表示当括号内的值为正数时取该值,为负数或0时取0,||…||2为二范数运算。优选地,α可设置为0.5,β可设置为0.1,τ可设置为0.5。但需要说明的是,在实际应用中,各参数α、β及τ可以根据实际需求而设定,本发明不以此为限。
此外,还可以根据下述公式(3)或公式(4)计算小批量样本集的三元组损耗:
各参数及算式含义如上,在此不再赘述。
在步骤S112中,根据小批量样本集的三元组损失,训练神经网络模型。
例如,可以按照Momentum(动量)梯度下降法来训练该神经网络模型,该神经网络例如为三元组神经网络。优选地,可以将Momentum梯度下降法中的Momentum系数设置为0.9,学习效率设置为0.001,学习率衰减设置为0.5,但本发明不以此为限。
在步骤S114中,当该小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
例如,可以根据TensorBoard可视化方法实时观察三元组损失曲线是否收敛到预期状态,如果没有,则可以重复上述步骤S104~S112,重新选取小批量样本集继续训练该神经网络模型,直至三元组损失曲线收敛到预期的状态。这时停止训练,确定当前的神经网络模型为训练完成后的神经网络模型,并存储训练得到的网络参数。
根据本发明实施方式提供的神经网络模型的训练方法,利用难例挖掘的采样策略训练样本数据,从而能够加快网络收敛速度、提高网络泛化能力。此外,根据本发明实施方式的另一些实施例,利用改进的损耗计算方法来反馈优化神经网络,可以提高神经网络训练性能,使同类样本特征内聚性更高,不同类样本特征耦合性更低。
应清楚地理解,本发明描述了如何形成和使用特定示例,但本发明的原理不限于这些示例的任何细节。相反,基于本发明公开的内容的教导,这些原理能够应用于许多其它实施方式。
图2是根据一示例性实施方式示出的另一种神经网络模型的训练方法的流程图。
参考图2,神经网络模型的训练方法20包括:
在步骤S202中,获取图像训练数据集。
其中,图像训练数据集包括N类待训练图像,N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数。
待训练图像例如可以为透视图像。
在步骤S204中,对各待训练图像分别进行图像预处理操作。
在一些实施例中,图像预处理操作例如可以包括:
步骤1,以固定尺寸,对待训练图像进行图像缩放处理。
例如,可以利用OpenCV2中的库函数resize()将待训练图像缩放到固定尺寸,即缩放到固定的宽和高,其中宽优选设置为225像素,但本发明不以此为限。
步骤2,对缩放后的待训练图像进行滤波处理。
例如,可以利用滤波核大小为3*3的中值滤波方法对缩放后的待训练图像进行滤波,以削弱图像中的噪声。
步骤3,转换滤波后的待训练图像的数据类型。
例如,可以将其数据类型转换为float32类型。
在步骤S206中,根据图像训练数据集,创建小批量样本集。
其中,小批量样本集包括从N类待训练图像中选取的n类待训练图像,n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M。
例如,可以随机地从N类待训练图像中选取n类待训练图像作为小批量样本集。
在步骤S208中,基于深度神经网络对小批量样本集中的待训练图像进行特征提取,获得所述小批量样本集的第一特征集合{F1,F2,F3,……,Fs}。
例如,可以采用resnet深度神经网络模型对小批量样本集中的待训练图像进行特征提取,共获得s=n*m个第一特征。也即第一特征集合{F1,F2,F3,……,Fs}中的每个第一特征分别对应于小批量样本集中的n*m个待训练图像。
在步骤S210中,根据小批量样本集的第一特征集合{F1,F2,F3,……,Fs},确定小批量样本集的特征集合{f1,f2,f3,……,fs}。
其中,特征集合{f1,f2,f3,……,fs}中的各特征为第一特征集合{F1,F2,F3,……,Fs}中各第一特征的归一化特征。
例如,可以通过公式(5)计算各第一特征的归一化特征:
其中,Fj为第一特征,Fj(x)为第一特征中的任一像素值,F'(x)为归一化特征。
同样地,下标s=n*m,也即特征集合{f1,f2,f3,……,fs}中的每个特征分别对应于小批量样本集中的n*m个待训练图像。
在步骤S212中,依次以小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征。
为了克服三元组网络训练过程中部分容易识别的正负样本对网络性能贡献低导致网络收敛速度慢及泛化能力弱的问题,引入难例挖掘(hard sample mining)的步骤,分别以小批量样本集中的每一个待训练图像作为锚点,确定其对应的难正例(即正样本)与难负例(即负样本),并在特征集合{f1,f2,f3,……,fs}中确定其各难正例的特征与各难负例的特征。
在步骤S214中,根据上述公式(2),确定小批量样本集的三元组损失。
由于传统的三元组损耗计算方法使类内距离小于类间距离,但是不能保证同类样本的类内聚类空间小,为了解决由于同类样本特征差异大而导致网络难以收敛的问题,本发明实施方式在三元组的损耗计算中进一步加入约束类内距离部分的损失,具体可参见上述的公式(2)。
在步骤S216中,根据小批量样本集的三元组损失,训练神经网络模型。
例如,可以按照Momentum(动量)梯度下降法来训练该神经网络模型,该神经网络例如为三元组神经网络。优选地,可以将Momentum梯度下降法中的Momentum系数设置为0.9,学习效率设置为0.001,学习率衰减设置为0.5,但本发明不以此为限。
在步骤S218中,判断该小批量样本集的三元组损失的曲线是否收敛到预期状态,如果是,进入步骤S220;否则,进入步骤S206。
例如,可以根据TensorBoard可视化方法实时观察该小批量样本集的三元组损失曲线是否收敛到预期状态,如果没有,则可以重复上述步骤S206~S216,重新构建小批量样本集,继续训练该神经网络模型,直至小批量样本集的三元组损失曲线收敛到预期的状态。
在步骤S220中,确定当前的神经网络模型为训练完成后的神经网络模型。
这时停止训练,确定当前的神经网络模型为训练完成后的神经网络模型,并存储训练得到的网络参数。
根据本发明实施方式提供的神经网络模型的训练方法,一方面,利用难例挖掘的采样策略训练样本数据,从而能够加快网络收敛速度、提高网络泛化能力;另一方面,利用改进的损耗计算方法来反馈优化神经网络,可以提高神经网络训练性能,使同类样本特征内聚性更高,不同类样本特征耦合性更低。
图3是根据一示例性实施方式示出的再一种神经网络模型的训练方法的流程图。具体地,图3所示的神经网络模型的训练方法30进一步说明如何确定一个小批量训练集中各锚点的难正例及难负例,及如何根据上述公式(2)计算小批量样本集的三元组损失。
参考图3,神经网络模型的训练方法30包括:
在步骤S302中,依次以小批量样本集中的每个待训练图像为锚点,分别在小批量样本集中确定每个锚点对应的正负样本特征集。
在小批量样本集中,依次确定每个待训练图像为锚点,并定义该锚点在特征集合{f1,f2,f3,……,fs}中对应的特征fa为该锚点的特征。并定义与每个锚点具有相同类别的待训练图像的特征集为正样本特征集A={fp 1,fp 2,fp 3,……,fp m-1},与每个锚点具有不同类别的待训练图像的特征集B={fn 1,fn 2,fn 3,……,fn s-m}。A和B称为一对正负样本特征集,这样一共可以构造出s=n*m对正负样本特征集。
在步骤S304,依次针对每个锚点的特征fa,对其s对正负样本特征图像集进行遍历,分别计算每对正负样本特征集的类内距离集合和类间距离集合。
例如,可以以下述公式(6)计算类内距离和类间距离:
dxy=||fx-fy||2 (6)
其中fx与fy为任意两个特征,||…||2为二范数运算。
类内距离即是根据公式(6)求锚点的特征fa与正样本特征集A={fp 1,fp 2,fp 3,……,fp m-1}内各正样本特征之间的二范数值,即类内距离集合Dap={dap 1,dap 2,dap 3,……,dap m-1}。
类间距离即是根据公式(6)求锚点的特征fa与负样本特征集B={fn 1,fn 2,fn 3,……,fn s-m}内各正样本特征之间的二范数值,即类间距离集合Dan={dan 1,dan 2,dan 3,……,dan s-m}。
在步骤S306中,分别确定每个锚点对应的难正例的特征及难负例的特征。
根据难例挖掘策略,确定每个锚点对应的难正例的特征及难负例的特征。
分别为每个锚点求取其类内距离集合Dap={dap 1,dap 2,dap 3,……,dap m-1}中的最大距离值dap max,则最大距离值dap max所对应的特征fp即为该锚点对应的难正例的特征。
分别为每个锚点求取其类间距离集合Dan={dan 1,dan 2,dan 3,……,dan s-m}中的最小距离值dan min,则最小距离值dan min所对应的特征fn即为该锚点对应的的难负例的特征。
在步骤S308中,分别根据每个锚点的特征及其对应的难正例的特征与难负例的特征,确定其三元组损失。
针对每个锚点,根据下述公式(7)计算其三元组损失:
L=(||fa-fp||2-||fa-fn||2+α)++β*(||fa-fp||2-τ)+ (7)
其中,L为三元组损失,fa为锚点的特征,fp为难正例的特征,fn为难负例的特征,α为约束难正例和难负例之间距离的阈值参数,β为用于约束同类待训练图像的类内距离的约束参数,例如可以为用于权衡传统的损耗与类内约束损耗所占的比重的参数,τ为用于约束同类待训练图像的特征的聚类空间的约束参数,()+表示当括号内的值为正数时取该值,为负数或0时取0,||…||2为二范数运算。优选地,α可设置为0.5,β可设置为0.1,τ可设置为0.5。但需要说明的是,在实际应用中,各参数α、β及τ可以根据实际需求而设定,本发明不以此为限。
将每个锚点的三元组损失集合记为{L1,L2,L3,……,Ls},其中s=n*m。
在步骤S310中,根据每个锚点的三元组损失,确定该小批量样本集的三元组损失。
根据下述公式(8),计算该小批量样本集的三元组损失:
其中,Triloss为该小批量样本集的三元组损失,Li为根据各锚点的特征及其对应的难正例的特征与难负例的特征计算的三元组损失,s=n*m。
本领域技术人员可以理解实现上述实施方式的全部或部分步骤被实现为由CPU执行的计算机程序。在该计算机程序被CPU执行时,执行本发明提供的上述方法所限定的上述功能。所述的程序可以存储于一种计算机可读存储介质中,该存储介质可以是只读存储器,磁盘或光盘等。
此外,需要注意的是,上述附图仅是根据本发明示例性实施方式的方法所包括的处理的示意性说明,而不是限制目的。易于理解,上述附图所示的处理并不表明或限制这些处理的时间顺序。另外,也易于理解,这些处理可以是例如在多个模块中同步或异步执行的。
下述为本发明装置实施例,可以用于执行本发明方法实施例。对于本发明装置实施例中未披露的细节,请参照本发明方法实施例。
图4是根据一示例性实施方式示出的一种神经网络模型的训练装置的框图。
参考图4,神经网络模型的训练装置40包括:训练集获取模块402、样本集创建模块404、特征集获取模块406、三元组确定模块408、损失确定模块410、模型训练模块412及模型确定模块414。
训练集获取模块402用于获取图像训练数据集。
其中,图像训练数据集包括N类待训练图像,N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数。
样本集创建模块404用于根据图像训练数据集,创建小批量样本集。
其中,小批量样本集包括从N类待训练图像中选取的n类待训练图像,n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M。
特征集获取模块406用于获取小批量样本集的特征集合{f1,f2,f3,……,fs}。
其中,s=n*m;
三元组确定模块408用于依次以小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征。
损失确定模块410用于根据每一个锚点的特征及其对应的难正例的特征与难负例的特征,确定小批量样本集的三元组损失。
模型训练模块412用于根据小批量样本集的三元组损失,训练神经网络模型。
模型确定模块414用于当小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
在一些实施例中,损失确定模块410包括:损失确定单元,用于根据下述公式计算三元组损失:
其中,Triloss为三元组损失,为锚点的特征,为难正例的特征,为难负例的特征,α为约束难正例和难负例之间距离的阈值参数,β为用于约束同类待训练图像的类内距离的约束参数,τ为用于约束同类待训练图像的特征的聚类空间的约束参数。
在一些实施例中,三元组确定模块408包括:同类图像确定单元、类内距离确定单元、难正例确定单元、异类图像确定单元、类间距离确定单元及难负例确定单元。其中,同类图像确定单元用于在小批量样本集中,确定与锚点属于同类的各同类待训练图像。类内距离确定单元用于确定锚点的特征与各同类待训练图像的特征之间的类内距离构成的类内距离集合。难正例确定单元用于确定类内距离集合中最大的类内距离值对应的同类待训练图像的特征为锚点对应的难正例的特征。异类图像确定单元用于在小批量样本集中,确定与锚点属于不同类的各不同类待训练图像。类间距离确定单元用于确定锚点的特征与各不同类待训练图像的特征之间的类间距离构成的类间距离集合。难负例确定单元用于确定类间距离集合中最小的类内距离值对应的不同类待训练图像的特征为锚点对应的难负例的特征。
在一些实施例中,装置40还包括:图像预处理模块,用于对各待训练图像分别进行图像预处理操作。
在一些实施例中,图像预处理模块包括:图像缩放单元、图像滤波单元及类型转换单元。其中,图像缩放单元用于以固定尺寸,对待训练图像进行图像缩放处理。图像滤波单元用于对缩放后的待训练图像进行滤波处理。类型转换单元用于转换滤波后的待训练图像的数据类型。
在一些实施例中,特征集获取模块406包括:第一特征获取单元及归一化处理单元。其中,第一特征获取单元用于基于深度神经网络对小批量样本集中的待训练图像进行特征提取,获得小批量样本集的第一特征集合{F1,F2,F3,……,Fs}。归一化处理单元用于根据小批量样本集的第一特征集合{F1,F2,F3,……,Fs},确定小批量样本集的特征集合{f1,f2,f3,……,fs};其中,特征集合{f1,f2,f3,……,fs}中的各特征为第一特征集合{F1,F2,F3,……,Fs}中各第一特征的归一化特征。
在一些实施例中,归一化处理单元根据下述公式计算第一特征的归一化特征:
其中,Fj为第一特征,Fj(x)为第一特征中的任一像素值,F'(x)为归一化特征。
根据本发明实施方式提供的神经网络模型的训练装置,利用难例挖掘的采样策略训练样本数据,从而能够加快网络收敛速度、提高网络泛化能力。此外,根据本发明实施方式的另一些实施例,利用改进的损耗计算方法来反馈优化神经网络,可以提高神经网络训练性能,使同类样本特征内聚性更高,不同类样本特征耦合性更低。
需要注意的是,上述附图中所示的框图是功能实体,不一定必须与物理或逻辑上独立的实体相对应。可以采用软件形式来实现这些功能实体,或在一个或多个硬件模块或集成电路中实现这些功能实体,或在不同网络和/或处理器装置和/或微控制器装置中实现这些功能实体。
图5是根据一示例性实施方式示出的一种电子设备的结构示意图。需要说明的是,图5示出的电子设备仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。
如图5所示,电子设备800以通用计算机设备的形式表现。电子设备800的组件包括:至少一个中央处理单元(CPU)801,其可以根据存储在只读存储器(ROM)802中的程序代码或者从至少一个存储单元808加载到随机访问存储器(RAM)803中的程序代码而执行各种适当的动作和处理。
特别地,根据本发明的实施例,所述程序代码可以被中央处理单元801执行,使得中央处理单元801执行本说明书上述方法实施例部分中描述的根据本发明各种示例性实施方式的步骤。例如,中央处理单元801可以执行如图1-图3中所示的步骤。
在RAM 803中,还存储有电子设备800操作所需的各种程序和数据。CPU 801、ROM802以及RAM 803通过总线804彼此相连。输入/输出(I/O)接口805也连接至总线804。
以下部件连接至I/O接口805:包括键盘、鼠标等的输入单元806;包括诸如阴极射线管(CRT)、液晶显示器(LCD)等以及扬声器等的输出单元807;包括硬盘等的存储单元808;以及包括诸如LAN卡、调制解调器等的网络接口卡的通信单元809。通信单元809经由诸如因特网的网络执行通信处理。驱动器810也根据需要连接至I/O接口805。可拆卸介质811,诸如磁盘、光盘、磁光盘、半导体存储器等等,根据需要安装在驱动器810上,以便于从其上读出的计算机程序根据需要被安装入存储单元808。
图6是根据一示例性实施方式示出的一种计算机可读存储介质的示意图。
参考图6所示,描述了根据本发明的实施方式的设置为实现上述方法的程序产品900,其可以采用便携式紧凑盘只读存储器(CD-ROM)并包括程序代码,并可以在终端设备,例如个人电脑上运行。然而,本发明的程序产品不限于此,在本文件中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
上述计算机可读介质承载有一个或者多个程序,当上述一个或者多个程序被一个该设备执行时,使得该计算机可读介质实现如下功能:
获取图像训练数据集;其中,所述图像训练数据集包括N类待训练图像,所述N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数;
根据所述图像训练数据集,创建小批量样本集;其中,所述小批量样本集包括从所述N类待训练图像中选取的n类待训练图像,所述n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M;
获取所述小批量样本集的特征集合{f1,f2,f3,……,fs};其中,s=n*m;
依次以所述小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在所述特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征;
根据每一个所述锚点的特征及其对应的所述难正例的特征与所述难负例的特征,确定三元组损失;
根据所述三元组损失,训练神经网络模型;以及
当所述三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
以上具体地示出和描述了本发明的示例性实施方式。应可理解的是,本发明不限于这里描述的详细结构、设置方式或实现方法;相反,本发明意图涵盖包含在所附权利要求的精神和范围内的各种修改和等效设置。
Claims (10)
1.一种神经网络模型的训练方法,其特征在于,包括:
获取图像训练数据集;其中,所述图像训练数据集包括N类待训练图像,所述N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数;
根据所述图像训练数据集,创建小批量样本集;其中,所述小批量样本集包括从所述N类待训练图像中选取的n类待训练图像,所述n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M;
获取所述小批量样本集的特征集合{f1,f2,f3,……,fs};其中,s=n*m;
依次以所述小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在所述特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征;
根据每一个所述锚点的特征及其对应的所述难正例的特征与所述难负例的特征,确定所述小批量样本集的三元组损失;
根据所述小批量样本集的三元组损失,训练神经网络模型;以及
当所述小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
3.根据权利要求2所述的方法,其特征在于,为每个锚点在所述特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征包括:
在所述小批量样本集中,确定与所述锚点属于同类的各同类待训练图像;
确定所述锚点的特征与各同类待训练图像的特征之间的类内距离构成的类内距离集合;
确定所述类内距离集合中最大的类内距离值对应的所述同类待训练图像的特征为所述锚点对应的难正例的特征;
在所述小批量样本集中,确定与所述锚点属于不同类的各不同类待训练图像;
确定所述锚点的特征与各不同类待训练图像的特征之间的类间距离构成的类间距离集合;以及
确定所述类间距离集合中最小的类内距离值对应的所述不同类待训练图像的特征为所述锚点对应的难负例的特征。
4.根据权利要求2所述的方法,其特征在于,还包括:对各所述待训练图像分别进行图像预处理操作。
5.根据权利要求4所述的方法,其特征在于,所述图像预处理操作包括:
以固定尺寸,对所述待训练图像进行图像缩放处理;
对缩放后的所述待训练图像进行滤波处理;以及
转换滤波后的所述待训练图像的数据类型。
6.根据权利要求2所述的方法,其特征在于,获得所述小批量样本集的特征集合{f1,f2,f3,……,fs}包括:
基于深度神经网络对所述小批量样本集中的待训练图像进行特征提取,获得所述小批量样本集的第一特征集合{F1,F2,F3,……,Fs};以及
根据所述小批量样本集的第一特征集合{F1,F2,F3,……,Fs},确定所述小批量样本集的特征集合{f1,f2,f3,……,fs};
其中,所述特征集合{f1,f2,f3,……,fs}中的各特征为所述第一特征集合{F1,F2,F3,……,Fs}中各第一特征的归一化特征。
8.一种神经网络模型的训练装置,其特征在于,包括:
训练集获取模块,用于获取图像训练数据集;其中,所述图像训练数据集包括N类待训练图像,所述N类待训练图像中的每一类包括M张待训练图像,N、M均为正整数;
样本集创建模块,用于根据所述图像训练数据集,创建小批量样本集;其中,所述小批量样本集包括从所述N类待训练图像中选取的n类待训练图像,所述n类待训练图像中的每一类包括m张待训练图像,n、m均为正整数且n<N、m<M;
特征集获取模块,用于获取所述小批量样本集的特征集合{f1,f2,f3,……,fs};其中,s=n*m;
三元组确定模块,用于依次以所述小批量样本集中的每一个待训练图像作为锚点,分别为每个锚点在所述特征集合{f1,f2,f3,……,fs}中确定其对应的难正例的特征与难负例的特征;
损失确定模块,用于根据每一个所述锚点的特征及其对应的所述难正例的特征与所述难负例的特征,确定所述小批量样本集的三元组损失;
模型训练模块,用于根据所述小批量样本集的三元组损失,训练神经网络模型;以及
模型确定模块,用于当所述小批量样本集的三元组损失的曲线收敛到预期状态时,确定当前的神经网络模型为训练完成后的神经网络模型。
9.一种计算机设备,包括:存储器、处理器及存储在所述存储器中并可在所述处理器中运行的可执行指令,其特征在于,所述处理器执行所述可执行指令时实现如权利要求1-7任一项所述的方法。
10.一种计算机可读存储介质,其上存储有计算机可执行指令,其特征在于,所述可执行指令被处理器执行时实现如权利要求1-7任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811592198.1A CN111368989B (zh) | 2018-12-25 | 2018-12-25 | 神经网络模型的训练方法、装置、设备及可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811592198.1A CN111368989B (zh) | 2018-12-25 | 2018-12-25 | 神经网络模型的训练方法、装置、设备及可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111368989A true CN111368989A (zh) | 2020-07-03 |
CN111368989B CN111368989B (zh) | 2023-06-16 |
Family
ID=71211477
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201811592198.1A Active CN111368989B (zh) | 2018-12-25 | 2018-12-25 | 神经网络模型的训练方法、装置、设备及可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111368989B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111931698A (zh) * | 2020-09-08 | 2020-11-13 | 平安国际智慧城市科技股份有限公司 | 基于小型训练集的图像深度学习网络构建方法及装置 |
CN113239217A (zh) * | 2021-06-04 | 2021-08-10 | 图灵深视(南京)科技有限公司 | 图像索引库构建方法及系统,图像检索方法及系统 |
CN113486914A (zh) * | 2020-10-20 | 2021-10-08 | 腾讯科技(深圳)有限公司 | 图像特征提取的神经网络训练的方法、装置和存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180114055A1 (en) * | 2016-10-25 | 2018-04-26 | VMAXX. Inc. | Point to Set Similarity Comparison and Deep Feature Learning for Visual Recognition |
CN108304864A (zh) * | 2018-01-17 | 2018-07-20 | 清华大学 | 深度对抗度量学习方法及装置 |
CN108399428A (zh) * | 2018-02-09 | 2018-08-14 | 哈尔滨工业大学深圳研究生院 | 一种基于迹比准则的三元组损失函数设计方法 |
CN108647577A (zh) * | 2018-04-10 | 2018-10-12 | 华中科技大学 | 一种自适应难例挖掘的行人重识别模型、方法与系统 |
-
2018
- 2018-12-25 CN CN201811592198.1A patent/CN111368989B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180114055A1 (en) * | 2016-10-25 | 2018-04-26 | VMAXX. Inc. | Point to Set Similarity Comparison and Deep Feature Learning for Visual Recognition |
CN108304864A (zh) * | 2018-01-17 | 2018-07-20 | 清华大学 | 深度对抗度量学习方法及装置 |
CN108399428A (zh) * | 2018-02-09 | 2018-08-14 | 哈尔滨工业大学深圳研究生院 | 一种基于迹比准则的三元组损失函数设计方法 |
CN108647577A (zh) * | 2018-04-10 | 2018-10-12 | 华中科技大学 | 一种自适应难例挖掘的行人重识别模型、方法与系统 |
Non-Patent Citations (1)
Title |
---|
周华捷 等: "深度学习下的行人再识别问题研究" * |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111931698A (zh) * | 2020-09-08 | 2020-11-13 | 平安国际智慧城市科技股份有限公司 | 基于小型训练集的图像深度学习网络构建方法及装置 |
CN113486914A (zh) * | 2020-10-20 | 2021-10-08 | 腾讯科技(深圳)有限公司 | 图像特征提取的神经网络训练的方法、装置和存储介质 |
CN113486914B (zh) * | 2020-10-20 | 2024-03-05 | 腾讯科技(深圳)有限公司 | 图像特征提取的神经网络训练的方法、装置和存储介质 |
CN113239217A (zh) * | 2021-06-04 | 2021-08-10 | 图灵深视(南京)科技有限公司 | 图像索引库构建方法及系统,图像检索方法及系统 |
CN113239217B (zh) * | 2021-06-04 | 2024-02-06 | 图灵深视(南京)科技有限公司 | 图像索引库构建方法及系统,图像检索方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN111368989B (zh) | 2023-06-16 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109146892B (zh) | 一种基于美学的图像裁剪方法及装置 | |
WO2021098362A1 (zh) | 视频分类模型构建、视频分类的方法、装置、设备及介质 | |
WO2020224403A1 (zh) | 分类任务模型的训练方法、装置、设备及存储介质 | |
CN108898186B (zh) | 用于提取图像的方法和装置 | |
US20220147822A1 (en) | Training method and apparatus for target detection model, device and storage medium | |
WO2020224405A1 (zh) | 图像处理方法、装置、计算机可读介质及电子设备 | |
EP4053751A1 (en) | Method and apparatus for training cross-modal retrieval model, device and storage medium | |
US9025889B2 (en) | Method, apparatus and computer program product for providing pattern detection with unknown noise levels | |
CN111368989B (zh) | 神经网络模型的训练方法、装置、设备及可读存储介质 | |
CN103295022B (zh) | 图像相似度计算系统及方法 | |
JP2009282980A (ja) | 画像学習、自動注釈、検索方法及び装置 | |
CN113205041B (zh) | 结构化信息提取方法、装置、设备和存储介质 | |
US20220245465A1 (en) | Picture searching method and apparatus, electronic device and computer readable storage medium | |
CN110458133A (zh) | 基于生成式对抗网络的轻量级人脸检测方法 | |
CN111144566B (zh) | 神经网络权重参数的训练方法、特征分类方法及对应装置 | |
WO2022257614A1 (zh) | 物体检测模型的训练方法、图像检测方法及其装置 | |
CN110399826B (zh) | 一种端到端人脸检测和识别方法 | |
CN112508782A (zh) | 网络模型的训练方法、人脸图像超分辨率重建方法及设备 | |
CN115861462B (zh) | 图像生成模型的训练方法、装置、电子设备及存储介质 | |
WO2023147717A1 (zh) | 文字检测方法、装置、电子设备和存储介质 | |
US9081800B2 (en) | Object detection via visual search | |
CN112036316B (zh) | 手指静脉识别方法、装置、电子设备及可读存储介质 | |
CN113673465A (zh) | 图像检测方法、装置、设备及可读存储介质 | |
CN117435896A (zh) | 一种在不平衡分类场景下无分割的验证集合成方法 | |
WO2020252898A1 (zh) | 眼底oct影像增强方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |