CN111325223B - 深度学习模型的训练方法、装置和计算机可读存储介质 - Google Patents

深度学习模型的训练方法、装置和计算机可读存储介质 Download PDF

Info

Publication number
CN111325223B
CN111325223B CN201811521621.9A CN201811521621A CN111325223B CN 111325223 B CN111325223 B CN 111325223B CN 201811521621 A CN201811521621 A CN 201811521621A CN 111325223 B CN111325223 B CN 111325223B
Authority
CN
China
Prior art keywords
loss function
function value
training
sample
determining
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN201811521621.9A
Other languages
English (en)
Other versions
CN111325223A (zh
Inventor
李旭锟
张信豪
杜鹏
邹洪亮
李明
任新新
汪庆寿
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
China Telecom Corp Ltd
Original Assignee
China Telecom Corp Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by China Telecom Corp Ltd filed Critical China Telecom Corp Ltd
Priority to CN201811521621.9A priority Critical patent/CN111325223B/zh
Publication of CN111325223A publication Critical patent/CN111325223A/zh
Application granted granted Critical
Publication of CN111325223B publication Critical patent/CN111325223B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本公开涉及一种深度学习模型的训练方法、装置和计算机可读存储介质,涉及计算机技术领域。本公开的方法包括:将训练样本输入待训练的深度学习模型,训练样本包括:锚样本、正样本和负样本;根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值;根据损失函数值对待训练的深度学习模型的参数进行调整,以便完成对待训练的深度学习模型的训练。本公开的方案加快了训练的收敛速度,提高训练效率。

Description

深度学习模型的训练方法、装置和计算机可读存储介质
技术领域
本公开涉及计算机技术领域,特别涉及一种深度学习模型的训练方法、装置和计算机可读存储介质。
背景技术
近几年深度学习在安防、教育、医疗健康、金融等领域取得了突破性的进展,例如语音识别、图像识别等等。可以说到目前为止,深度学习是最接近人类大脑的智能学习方法。但是深度学习模型参数多,计算量大,训练数据的规模也更大,在一些项目中,往往需要训练几个月甚至更长,这大大降低了训练的效率,因此如何加快训练速度也是深度学习中一个亟待解决的问题。
图像识别、人脸识别等是目前研究热度很高的领域,大多图像识别、人脸识别模型都应用了深度学习技术。在深度学习模型的训练过程中,常用的损失函数有Triplet Loss(三元组损失)。将训练样本输入深度学习模型,通过计算Triplet Loss调整模型的参数,完成模型的训练。
发明内容
发明人发现:采用Triplet Loss对深度学习模型进行实际训练时,收敛速度不快,训练效率较低。
本公开所要解决的一个技术问题是:提高深度学习模型的训练效率。
根据本公开的一些实施例,提供的一种深度学习模型的训练方法,包括:将训练样本输入待训练的深度学习模型,训练样本包括:锚样本、正样本和负样本;根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值;根据损失函数值对待训练的深度学习模型的参数进行调整,以便完成对待训练的深度学习模型的训练。
在一些实施例中,确定损失函数值的方法包括:根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;将第一损失函数值与第二损失函数值的加权和确定为损失函数值。
在一些实施例中,确定损失函数值的方法还包括:计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在差值超过预设范围内的情况下,将第一损失函数值与第二损失函数值的加权和确定为当前训练周期的损失函数值;或者,在差值在预设范围内的情况下,将第二损失函数值确定为当前训练周期的损失函数值。
在一些实施例中,确定损失函数值的方法包括:根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;根据输出的类中心的特征与训练样本中心的特征的距离,确定第三损失函数值;将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为损失函数值。
在一些实施例中,确定损失函数值的方法还包括:计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在差值超过预设范围内的情况下,将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为当前训练周期的损失函数值;或者,在差值在预设范围内的情况下,将第二损失函数值确定为当前训练周期的损失函数值。
在一些实施例中,第一损失函数值采用以下公式确定:
其中,m表示输入的训练样本的数量,i表示训练样本的编号,1≤i≤m,i为正整数,g(xi)表示第i个训练样本xi的特征,表示第i个训练样本xi对应的类别yi的类中心的特征;
或者,第二损失函数值采用以下公式确定:
其中,j表示训练样本三元组的编号,每个三元组中包含一个锚样本,一个正样本和一个负样本,表示第j个三元组中锚样本的特征,/>第j个三元组中正样本的特征,/>表示第j个三元组中负样本的特征,α为常数;
或者,第三损失函数值采用以下公式确定:
其中,cbc表示训练样本中心的特征。
在一些实施例中,第一损失函数值对应的权重随着训练周期数量的增加而减小;或者第三损失函数值对应的权重随着训练周期数量的增加而减小。
根据本公开的另一些实施例,提供的一种深度学习模型的训练装置,包括:输入模块,用于将训练样本输入待训练的深度学习模型,训练样本包括:锚样本、正样本和负样本;损失函数确定模块,用于根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值;调整模块,用于根据损失函数值对待训练的深度学习模型的参数进行调整,以便完成对待训练的深度学习模型的训练。
在一些实施例中,损失函数确定模块用于根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;将第一损失函数值与第二损失函数值的加权和确定为损失函数值。
在一些实施例中,损失函数确定模块用于计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在差值超过预设范围内的情况下,将第一损失函数值与第二损失函数值的加权和确定为当前训练周期的损失函数值;或者,在差值在预设范围内的情况下,将第二损失函数值确定为当前训练周期的损失函数值。
在一些实施例中,损失函数确定模块用于根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;根据输出的类中心的特征与训练样本中心的特征的距离,确定第三损失函数值;将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为损失函数值。
在一些实施例中,损失函数确定模块用于计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在差值超过预设范围内的情况下,将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为当前训练周期的损失函数值;或者,在差值在预设范围内的情况下,将第二损失函数值确定为当前训练周期的损失函数值。
在一些实施例中,第一损失函数值采用以下公式确定:
其中,m表示输入的训练样本的数量,i表示训练样本的编号,1≤i≤m,i为正整数,g(xi)表示第i个训练样本xi的特征,表示第i个训练样本xi对应的类别yi的类中心的特征;
或者,第二损失函数值采用以下公式确定:
其中,j表示训练样本三元组的编号,每个三元组中包含一个锚样本,一个正样本和一个负样本,表示第j个三元组中锚样本的特征,/>第j个三元组中正样本的特征,/>表示第j个三元组中负样本的特征,α为常数;
或者,第三损失函数值采用以下公式确定:
其中,cbc表示训练样本中心的特征。
在一些实施例中,第一损失函数值对应的权重随着训练周期数量的增加而减小;或者第三损失函数值对应的权重随着训练周期数量的增加而减小。
根据本公开的又一些实施例,提供的一种深度学习模型的训练装置,包括:存储器;以及耦接至存储器的处理器,处理器被配置为基于存储在存储器中的指令,执行如前述任意实施例的深度学习模型的训练方法。
根据本公开的再一些实施例,提供的一种计算机可读存储介质,其上存储有计算机程序,其中,该程序被处理器执行时实现前述任意实施例的深度学习模型的训练方法。
本公开中改进了损失函数的计算方法,在深度学习模型的训练过程中,根据输出的训练样本的特征与类中心的特征的距离,锚样本的特征与正样本的特征的距离,以及锚样本的特征与负样本的特征的距离,确定损失函数值,实现对深度学习模型的训练。由于计算损失函数参考类中心的特征,相当于参考了一个相对稳定的收敛中心,降低了单组数据的偏离造成收敛的偏差,所以收敛过程基本上是按梯度最大的方向进行收敛,减少了部分无效甚至负作用的迭代,加快了训练的收敛速度,提高训练效率。
通过以下参照附图对本公开的示例性实施例的详细描述,本公开的其它特征及其优点将会变得清楚。
附图说明
为了更清楚地说明本公开实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出本公开的一些实施例的深度学习模型的训练方法的流程示意图。
图2示出本公开的另一些实施例的深度学习模型的训练方法的流程示意图。
图3示出本公开的又一些实施例的深度学习模型的训练方法的流程示意图。
图4示出本公开的一些实施例的深度学习模型的训练装置的结构示意图。
图5示出本公开的另一些实施例的深度学习模型的训练装置的结构示意图。
图6示出本公开的又一些实施例的深度学习模型的训练装置的结构示意图。
具体实施方式
下面将结合本公开实施例中的附图,对本公开实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本公开一部分实施例,而不是全部的实施例。以下对至少一个示例性实施例的描述实际上仅仅是说明性的,决不作为对本公开及其应用或使用的任何限制。基于本公开中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本公开保护的范围。
针对采用Triplet Loss对深度学习模型进行实际训练时,收敛速度不快,训练效率较低的问题,提出本方案,下面结合图1进行描述。
图1为本公开深度学习模型的训练方法一些实施例的流程图。如图1所示,该实施例的方法包括:步骤S102~S106。
在步骤S102中,将训练样本输入待训练的深度学习模型,训练样本包括:锚样本、正样本和负样本。
例如,深度学习模型为图片识别模型或人脸识别模型的情况下,训练样本可以是图片。人脸识别模型例如为FaceNet等模型。将训练样本划分为不同的三元组,三元组中一个训练样本作为锚(Anchor)样本,一个与锚样本属于同一类别的训练样本作为正(Positive)样本,一个与属于不同类别的训练样本作为负(Negative)样本。上述训练样本的确定过程与利用Triplet Loss进行训练时的训练样本的准备过程相同。
待训练的深度学习模型可以是经过预先训练的深度学习模型,精确度未达到需要的高度,可以采用本公开的方案进行进一步训练,以提高深度学习模型的精确度。深度学习模型训练时,可以将训练样本根据Batch size(批尺寸)划分为不同的Batch(批),每次迭代在深度学习模型中输入一批训练样本进行训练。
在步骤S104中,根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值。
训练样本输入深度学习模型后可以得到输出的训练样本的特征,特征可以用向量表示,并得到各个训练样本所属的类别,和各个类别的类中心。类中心的特征例如为类别中各个训练样本的特征的均值。
在一些实施例中,根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;将第一损失函数值与第二损失函数值的加权和确定为损失函数值。
例如,第一损失函数值可以采用以下公式确定。
公式(1)中,m表示输入的训练样本的数量,i表示训练样本的编号,1≤i≤m,i为正整数,g(xi)表示第i个训练样本xi的特征,cyi表示第i个训练样本xi所属的类别yi的类中心的特征,表示二范数的平方的运算,g(·)表示深度学习模型中部分子网络的运算函数。
例如,第二损失函数值采用以下公式确定。
公式(2)中,j表示训练样本三元组的编号,每个三元组中包含一个锚样本,一个正样本和一个负样本,表示第j个三元组中锚样本的特征,/>第j个三元组中正样本的特征,/>表示第j个三元组中负样本的特征,α为预设的常数,f(·)表示深度学习模型中部分子网络的运算函数,可以与g(·)相同或不同。公式(2)可以参考Triplet Loss的计算公式。
进一步,最终的损失函数值可以采用以下公式确定。
L=λL1+θL2 (3)
公式(3)中,λ和θ分别为L1和L2的权重,λ和θ可以根据实际训练需求进行设置,例如,λ可以设置为小于1的正数,θ可以设置为1。
为了进一步加快收敛速度,提高深度学习模型训练的效率。可以进一步改进损失函数的计算方法,在一些实施例中,根据输出的训练样本的特征与对应的类中心的特征的距离,类中心的特征与训练样本中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值。通过进一步参考类中心的特征与训练样本中心的特征的距离,使用训练样本中心收敛类中心,从而进一步加快了收敛速度和训练效率。训练样本中心的特征例如为输入的训练样本的特征的均值。
在一些实施例中,根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;根据输出的类中心的特征与训练样本中心的特征的距离,确定第三损失函数值;将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为损失函数值。第一损失函数值和第二损失函数值的计算可以参考上述公式(1)和(2)。第三损失函数值可以采用以下公式确定。
公式(4)中,cbc表示训练样本中心的特征。其他参数参考前述公式(1)-(3)。
进一步,最终的损失函数值可以采用以下公式确定。
公式(5)中,为L3的权重,/>可以根据实际训练需求进行设置,例如,λ可以设置为小于1的正数,θ可以设置为1,/>可以设置为小于1的正数,/>可以与λ相等。
在步骤S106中,根据损失函数值对待训练的深度学习模型的参数进行调整,以便完成对待训练的深度学习模型的训练。
计算得到损失函数值之后,可以参考现有技术的方法对训练的深度学习模型的参数进行调整,例如通过反向传播、梯度下降等方法对深度学习模型的权重进行调整。通过不断的迭代和训练,直至达到训练停止条件,完成深度学习模型的训练。训练停止条件例如为损失函数值不再下降或者损失函数值低于阈值等,可以根据实际需求进行设置。
上述实施例中改进了损失函数的计算方法,在深度学习模型的训练过程中,根据输出的训练样本的特征与类中心的特征的距离,锚样本的特征与正样本的特征的距离,以及锚样本的特征与负样本的特征的距离,确定损失函数值,实现对深度学习模型的训练。由于计算损失函数参考类中心的特征,相当于参考了一个相对稳定的收敛中心,降低了单组数据的偏离造成收敛的偏差,所以收敛过程基本上是按梯度最大的方向进行收敛,减少了部分无效甚至负作用的迭代,加快了训练的收敛速度。
在完成深度学习模型的训练之后,可以将待识别的对象(例如,图片等)输入深度学习模型,确定待识别的对象的类别。
深度学习模型的训练需要经过多个训练周期,每个训练周期例如为一次训练迭代过程,例如包括:输入一批训练样本,正向传播,计算损失函数值,反向传播,使用梯度下降进行反向参数更新等过程。一个训练周期结束后可以选取下一批训练样本,进入下一个训练周期,重复上述各个步骤,直至达到训练停止条件,完成训练。可以根据不同训练周期中损失函数值的变化情况,为不同的训练周期选取合适的损失函数计算方法,下面结合图2描述本公开深度学习模型的训练方法的另一些实施例。
图2为本公开深度学习模型的训练方法另一些实施例的流程图。如图2所示,该实施例的方法包括:步骤S202~S212。
在步骤S202中,选取一批训练样本输入待训练的深度学习模型。
可以每次从所有的训练样本中选取一批预设数量的训练样本,训练样本可以被划分为不同的三元组,包括:锚样本、正样本和负样本。
在步骤S204中,判断当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值是否超过预设范围,如果超过,则执行步骤S206,否则执行步骤S208。
训练样本输入待训练的深度学习模型可以得到训练样本的特征,所属的不同的类别,和各个类别的类中心的特征,第一损失函数值根据训练样本的特征与对应的类中心的特征的距离确定,可以参考前述实施例。在当前周期为第一个训练周期,由于不存在上一周期的第一损失函数值,可以直接将第二损失函数值确定为第一个训练周期的损失函数值,根据第二损失函数值对待训练的深度学习模型的参数进行调整。完成第一个训练周期的训练,之后从步骤S202开始执行。
判断当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值是否超过预设范围可以采用以下公式进行表示。
公式(6)中,β为阈值。
在步骤S206中,将第一损失函数值与第二损失函数值的加权和确定为当前训练周期的损失函数值。
可以参考前述公式(1)~(3)确定当前训练周期的损失函数值。在一些实施例中,第一损失函数值对应的权重随着训练周期数量的增加而减小。例如,可以设置比例系数,每次将第一损失函数值对应的权重乘以比例系数,比例系数为小于1的正数。
在步骤S208中,将第二损失函数值确定为当前训练周期的损失函数值。
如果当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值在预设范围内,表明第一损失函数值的下降幅度很小了,再根据第一损失函数对模型进行调整,对模型的训练效率和精确度的提高则没有太大效果了,这种情况下,仅将第二损失函数值确定为损失函数值,节省计算量。
在步骤S210中,根据当前周期的损失函数值对待训练的深度学习模型的参数进行调整。
在步骤S212中,判读是否达到训练停止条件,如果达到则结束,否则返回步骤S202重新开始执行。
结合前述实施例,损失函数值还可以根据第一损失函数值、第二损失函数值与第三损失函数值确定,下面结合图3描述本公开深度学习模型的训练方法又一些实施例。
图3为本公开深度学习模型的训练方法又一些实施例的流程图。如图3所示,该实施例的方法包括:步骤S302~S312。
在步骤S302中,选取一批训练样本输入待训练的深度学习模型。
可以每次从所有的训练样本中选取一批预设数量的训练样本,训练样本可以被划分为不同的三元组,包括:锚样本、正样本和负样本。
在步骤S304中,判断当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值是否超过预设范围,如果超过,则执行步骤S306,否则执行步骤S308。
在当前周期为第一个训练周期,由于不存在上一周期的第一损失函数值,可以直接将第二损失函数值确定为第一个训练周期的损失函数值,根据第二损失函数值对待训练的深度学习模型的参数进行调整,完成第一个训练周期的训练,之后从步骤S302开始执行。
在步骤S306中,将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为损失函数值。
可以参考前述公式(1)、(2)、(4)和(5)确定当前训练周期的损失函数值。在一些实施例中,第一损失函数值对应的权重随着训练周期数量的增加而减小;第三损失函数值对应的权重随着训练周期数量的增加而减小。第一损失函数值对应的权重和第三损失函数值对应的权重可以对应相同的比例系数,每次将第一损失函数值对应的权重和第三损失函数值对应的权重乘以比例系数,比例系数为小于1的正数。
在步骤S308中,将第二损失函数值确定为当前训练周期的损失函数值。
在步骤S310中,根据当前周期的损失函数值对待训练的深度学习模型的参数进行调整。
在步骤S312中,判读是否达到训练停止条件,如果达到则结束,否则返回步骤S302重新开始执行。
上述实施例的方法,通过在不同的训练周期判断第一损失函数值是否下降,选取不同的损失函数计算方法,在提高训练效率的情况下,能够节省计算量。
本公开还提供一种深度学习模型的训练装置,下面结合图4进行描述。
图4为本公开深度学习模型的训练装置的一些实施例的结构图。如图4所示,该实施例的装置40包括:输入模块402,损失函数确定模块404,调整模块406。
输入模块402,用于将训练样本输入待训练的深度学习模型,训练样本包括:锚样本、正样本和负样本。
损失函数确定模块404,用于根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值。
在一些实施例中,损失函数确定模块404用于根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;将第一损失函数值与第二损失函数值的加权和确定为损失函数值。
在一些实施例中,损失函数确定模块404用于计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在差值超过预设范围内的情况下,将第一损失函数值与第二损失函数值的加权和确定为当前训练周期的损失函数值;或者,在差值在预设范围内的情况下,将第二损失函数值确定为当前训练周期的损失函数值。
在一些实施例中,损失函数确定模块404用于根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;根据输出的类中心的特征与训练样本中心的特征的距离,确定第三损失函数值;将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为损失函数值。
在一些实施例中,损失函数确定模块404用于计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在差值超过预设范围内的情况下,将第一损失函数值、第二损失函数值与第三损失函数值的加权和确定为当前训练周期的损失函数值;或者,在差值在预设范围内的情况下,将第二损失函数值确定为当前训练周期的损失函数值。
在一些实施例中,第一损失函数值采用以下公式确定:
其中,m表示输入的训练样本的数量,i表示训练样本的编号,1≤i≤m,i为正整数,g(xi)表示第i个训练样本xi的特征,表示第i个训练样本xi对应的类别yi的类中心的特征;
或者,第二损失函数值采用以下公式确定:
其中,j表示训练样本三元组的编号,每个三元组中包含一个锚样本,一个正样本和一个负样本,表示第j个三元组中锚样本的特征,/>第j个三元组中正样本的特征,/>表示第j个三元组中负样本的特征,α为常数;
或者,第三损失函数值采用以下公式确定:
其中,cbc表示训练样本中心的特征。
在一些实施例中,第一损失函数值对应的权重随着训练周期数量的增加而减小;或者第三损失函数值对应的权重随着训练周期数量的增加而减小。
调整模块406,用于根据损失函数值对待训练的深度学习模型的参数进行调整,以便完成对待训练的深度学习模型的训练。
本公开的实施例中的深度学习模型的训练装置可各由各种计算设备或计算机系统来实现,下面结合图5以及图6进行描述。
图5为本公开深度学习模型的训练装置的一些实施例的结构图。如图5所示,该实施例的装置50包括:存储器510以及耦接至该存储器510的处理器520,处理器520被配置为基于存储在存储器510中的指令,执行本公开中任意一些实施例中的深度学习模型的训练方法。
其中,存储器510例如可以包括系统存储器、固定非易失性存储介质等。系统存储器例如存储有操作系统、应用程序、引导装载程序(Boot Loader)、数据库以及其他程序等。
图6为本公开深度学习模型的训练装置的另一些实施例的结构图。如图6所示,该实施例的装置60包括:存储器610以及处理器620,分别与存储器510以及处理器520类似。还可以包括输入输出接口630、网络接口640、存储接口650等。这些接口630,640,650以及存储器610和处理器620之间例如可以通过总线660连接。其中,输入输出接口630为显示器、鼠标、键盘、触摸屏等输入输出设备提供连接接口。网络接口640为各种联网设备提供连接接口,例如可以连接到数据库服务器或者云端存储服务器等。存储接口650为SD卡、U盘等外置存储设备提供连接接口。
本领域内的技术人员应当明白,本公开的实施例可提供为方法、系统、或计算机程序产品。因此,本公开可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本公开可采用在一个或多个其中包含有计算机可用程序代码的计算机可用非瞬时性存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本公开是参照根据本公开实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解为可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
以上所述仅为本公开的较佳实施例,并不用以限制本公开,凡在本公开的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本公开的保护范围之内。

Claims (12)

1.一种深度学习模型的训练方法,包括:
将训练样本输入待训练的深度学习模型,得到输出的训练样本的特征、各个训练样本所属的类别和各个类别的类中心,其中,所述训练样本包括:锚样本、正样本和负样本,所述深度学习模型为图片识别模型,所述训练样本为图片;
根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值;
根据所述损失函数值对所述待训练的深度学习模型的参数进行调整,以便完成对所述待训练的深度学习模型的训练;
将待识别图片输入所述深度学习模型,确定待识别图片的类别;
其中,确定损失函数值的方法包括:
根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;
根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;
根据输出的类中心的特征与训练样本中心的特征的距离,确定第三损失函数值;
将所述第一损失函数值、所述第二损失函数值与所述第三损失函数值的加权和确定为损失函数值,其中,所述第一损失函数值对应的权重随着训练周期数量的增加而减小,或者所述第三损失函数值对应的权重随着训练周期数量的增加而减小。
2.根据权利要求1所述的深度学习模型的训练方法,其中,
确定损失函数值的方法包括:
根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;
根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;
将所述第一损失函数值与所述第二损失函数值的加权和确定为损失函数值。
3.根据权利要求2所述的深度学习模型的训练方法,其中,
确定损失函数值的方法还包括:
计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;
在所述差值超过预设范围内的情况下,将所述第一损失函数值与所述第二损失函数值的加权和确定为当前训练周期的损失函数值;
或者,在所述差值在预设范围内的情况下,将所述第二损失函数值确定为当前训练周期的损失函数值。
4.根据权利要求1所述的深度学习模型的训练方法,其中,
确定损失函数值的方法还包括:
计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;
在所述差值超过预设范围内的情况下,将所述第一损失函数值、所述第二损失函数值与所述第三损失函数值的加权和确定为当前训练周期的损失函数值;
或者,在所述差值在预设范围内的情况下,将所述第二损失函数值确定为当前训练周期的损失函数值。
5.根据权利要求1所述的深度学习模型的训练方法,其中,
所述第一损失函数值采用以下公式确定:
其中,m表示输入的训练样本的数量,i表示训练样本的编号,1≤i≤m,i为正整数,g(xi)表示第i个训练样本xi的特征,cyi表示第i个训练样本xi对应的类别yi的类中心的特征;
或者,所述第二损失函数值采用以下公式确定:
其中,j表示训练样本三元组的编号,每个三元组中包含一个锚样本,一个正样本和一个负样本,表示第j个三元组中锚样本的特征,/>第j个三元组中正样本的特征,表示第j个三元组中负样本的特征,α为常数;
或者,所述第三损失函数值采用以下公式确定:
其中,cbc表示训练样本中心的特征。
6.一种深度学习模型的训练装置,包括:
输入模块,用于将训练样本输入待训练的深度学习模型,得到输出的训练样本的特征、各个训练样本所属的类别和各个类别的类中心,其中,所述训练样本包括:锚样本、正样本和负样本,所述深度学习模型为图片识别模型,所述训练样本为图片;
损失函数确定模块,用于根据输出的训练样本的特征与对应的类中心的特征的距离,锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定损失函数值;
调整模块,用于根据所述损失函数值对所述待训练的深度学习模型的参数进行调整,以便完成对所述待训练的深度学习模型的训练;
识别模块,用于将待识别图片输入所述深度学习模型,确定待识别图片的类别;
其中,所述损失函数确定模块用于根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;根据输出的类中心的特征与训练样本中心的特征的距离,确定第三损失函数值;将所述第一损失函数值、所述第二损失函数值与所述第三损失函数值的加权和确定为损失函数值,所述第一损失函数值对应的权重随着训练周期数量的增加而减小,或者所述第三损失函数值对应的权重随着训练周期数量的增加而减小。
7.根据权利要求6所述的深度学习模型的训练装置,其中,
所述损失函数确定模块用于根据输出的训练样本的特征与对应的类中心的特征的距离确定第一损失函数值;根据输出的锚样本的特征与对应的正样本的特征的距离,以及锚样本的特征与对应的负样本的特征的距离,确定第二损失函数值;将所述第一损失函数值与所述第二损失函数值的加权和确定为损失函数值。
8.根据权利要求7所述的深度学习模型的训练装置,其中,
所述损失函数确定模块用于计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在所述差值超过预设范围内的情况下,将所述第一损失函数值与所述第二损失函数值的加权和确定为当前训练周期的损失函数值;或者,在所述差值在预设范围内的情况下,将所述第二损失函数值确定为当前训练周期的损失函数值。
9.根据权利要求6所述的深度学习模型的训练装置,其中,
所述损失函数确定模块用于计算当前训练周期的第一损失函数值与上一训练周期的第一损失函数值的差值;在所述差值超过预设范围内的情况下,将所述第一损失函数值、所述第二损失函数值与所述第三损失函数值的加权和确定为当前训练周期的损失函数值;或者,在所述差值在预设范围内的情况下,将所述第二损失函数值确定为当前训练周期的损失函数值。
10.根据权利要求6所述的深度学习模型的训练装置,其中,
所述第一损失函数值采用以下公式确定:
其中,m表示输入的训练样本的数量,i表示训练样本的编号,1≤i≤m,i为正整数,g(xi)表示第i个训练样本xi的特征,cyi表示第i个训练样本xi对应的类别yi的类中心的特征;
或者,所述第二损失函数值采用以下公式确定:
其中,j表示训练样本三元组的编号,每个三元组中包含一个锚样本,一个正样本和一个负样本,表示第j个三元组中锚样本的特征,/>第j个三元组中正样本的特征,表示第j个三元组中负样本的特征,α为常数;
或者,所述第三损失函数值采用以下公式确定:
其中,cbc表示训练样本中心的特征。
11.一种深度学习模型的训练装置,包括:
存储器;以及
耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行如权利要求1-5任一项所述的深度学习模型的训练方法。
12.一种计算机可读存储介质,其上存储有计算机程序,其中,该程序被处理器执行时实现权利要求1-5任一项所述方法的步骤。
CN201811521621.9A 2018-12-13 2018-12-13 深度学习模型的训练方法、装置和计算机可读存储介质 Active CN111325223B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201811521621.9A CN111325223B (zh) 2018-12-13 2018-12-13 深度学习模型的训练方法、装置和计算机可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201811521621.9A CN111325223B (zh) 2018-12-13 2018-12-13 深度学习模型的训练方法、装置和计算机可读存储介质

Publications (2)

Publication Number Publication Date
CN111325223A CN111325223A (zh) 2020-06-23
CN111325223B true CN111325223B (zh) 2023-10-24

Family

ID=71168605

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201811521621.9A Active CN111325223B (zh) 2018-12-13 2018-12-13 深度学习模型的训练方法、装置和计算机可读存储介质

Country Status (1)

Country Link
CN (1) CN111325223B (zh)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111914761A (zh) * 2020-08-04 2020-11-10 南京华图信息技术有限公司 一种热红外人脸识别的方法及系统
CN112949384B (zh) * 2021-01-23 2024-03-08 西北工业大学 一种基于对抗性特征提取的遥感图像场景分类方法
CN113033622B (zh) * 2021-03-05 2023-02-03 北京百度网讯科技有限公司 跨模态检索模型的训练方法、装置、设备和存储介质
CN113420121B (zh) * 2021-06-24 2023-07-28 中国科学院声学研究所 文本处理模型训练方法、语音文本处理方法及装置
CN113408299B (zh) * 2021-06-30 2022-03-25 北京百度网讯科技有限公司 语义表示模型的训练方法、装置、设备和存储介质
CN113705111B (zh) * 2021-09-22 2024-04-26 百安居网络技术(上海)有限公司 一种基于深度学习的装修家具自动布局方法及系统

Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106897390A (zh) * 2017-01-24 2017-06-27 北京大学 基于深度度量学习的目标精确检索方法
CN108009528A (zh) * 2017-12-26 2018-05-08 广州广电运通金融电子股份有限公司 基于Triplet Loss的人脸认证方法、装置、计算机设备和存储介质
CN108182394A (zh) * 2017-12-22 2018-06-19 浙江大华技术股份有限公司 卷积神经网络的训练方法、人脸识别方法及装置
WO2018107760A1 (zh) * 2016-12-16 2018-06-21 北京大学深圳研究生院 一种用于行人检测的协同式深度网络模型方法
CN108197538A (zh) * 2017-12-21 2018-06-22 浙江银江研究院有限公司 一种基于局部特征和深度学习的卡口车辆检索系统及方法
CN108734193A (zh) * 2018-03-27 2018-11-02 合肥麟图信息科技有限公司 一种深度学习模型的训练方法及装置
WO2018219016A1 (zh) * 2017-06-02 2018-12-06 腾讯科技(深圳)有限公司 一种人脸检测训练方法、装置及电子设备
JP2019509551A (ja) * 2016-02-04 2019-04-04 エヌイーシー ラボラトリーズ アメリカ インクNEC Laboratories America, Inc. Nペア損失による距離計量学習の改善
CN111753583A (zh) * 2019-03-28 2020-10-09 阿里巴巴集团控股有限公司 一种识别方法及装置
KR20200135730A (ko) * 2019-05-22 2020-12-03 한국전자통신연구원 이미지 딥러닝 모델 학습 방법 및 장치
CN115134153A (zh) * 2022-06-30 2022-09-30 中国电信股份有限公司 安全评估方法、装置和模型训练方法、装置
CN115641613A (zh) * 2022-11-03 2023-01-24 西安电子科技大学 一种基于聚类和多尺度学习的无监督跨域行人重识别方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US9390370B2 (en) * 2012-08-28 2016-07-12 International Business Machines Corporation Training deep neural network acoustic models using distributed hessian-free optimization
US10270788B2 (en) * 2016-06-06 2019-04-23 Netskope, Inc. Machine learning based anomaly detection

Patent Citations (13)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2019509551A (ja) * 2016-02-04 2019-04-04 エヌイーシー ラボラトリーズ アメリカ インクNEC Laboratories America, Inc. Nペア損失による距離計量学習の改善
WO2018107760A1 (zh) * 2016-12-16 2018-06-21 北京大学深圳研究生院 一种用于行人检测的协同式深度网络模型方法
WO2018137358A1 (zh) * 2017-01-24 2018-08-02 北京大学 基于深度度量学习的目标精确检索方法
CN106897390A (zh) * 2017-01-24 2017-06-27 北京大学 基于深度度量学习的目标精确检索方法
WO2018219016A1 (zh) * 2017-06-02 2018-12-06 腾讯科技(深圳)有限公司 一种人脸检测训练方法、装置及电子设备
CN108197538A (zh) * 2017-12-21 2018-06-22 浙江银江研究院有限公司 一种基于局部特征和深度学习的卡口车辆检索系统及方法
CN108182394A (zh) * 2017-12-22 2018-06-19 浙江大华技术股份有限公司 卷积神经网络的训练方法、人脸识别方法及装置
CN108009528A (zh) * 2017-12-26 2018-05-08 广州广电运通金融电子股份有限公司 基于Triplet Loss的人脸认证方法、装置、计算机设备和存储介质
CN108734193A (zh) * 2018-03-27 2018-11-02 合肥麟图信息科技有限公司 一种深度学习模型的训练方法及装置
CN111753583A (zh) * 2019-03-28 2020-10-09 阿里巴巴集团控股有限公司 一种识别方法及装置
KR20200135730A (ko) * 2019-05-22 2020-12-03 한국전자통신연구원 이미지 딥러닝 모델 학습 방법 및 장치
CN115134153A (zh) * 2022-06-30 2022-09-30 中国电信股份有限公司 安全评估方法、装置和模型训练方法、装置
CN115641613A (zh) * 2022-11-03 2023-01-24 西安电子科技大学 一种基于聚类和多尺度学习的无监督跨域行人重识别方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
基于多辅助分支深度网络的行人再识别;夏开国;田畅;;通信技术(第11期);2601-2605 *
机器学习在数据挖掘中的应用;王泓正;;中国新技术新产品(第22期);98-99 *

Also Published As

Publication number Publication date
CN111325223A (zh) 2020-06-23

Similar Documents

Publication Publication Date Title
CN111325223B (zh) 深度学习模型的训练方法、装置和计算机可读存储介质
KR102170105B1 (ko) 신경 네트워크 구조의 생성 방법 및 장치, 전자 기기, 저장 매체
US20230252327A1 (en) Neural architecture search for convolutional neural networks
CN108664893B (zh) 一种人脸检测方法及存储介质
CN108090470B (zh) 一种人脸对齐方法及装置
CN111950723B (zh) 神经网络模型训练方法、图像处理方法、装置及终端设备
CN110689136B (zh) 一种深度学习模型获得方法、装置、设备及存储介质
KR20200049422A (ko) 시뮬레이션-가이드된 반복적 프루닝을 사용하는 효율적인 네트워크 압축
CN115393633A (zh) 数据处理方法、电子设备、存储介质及程序产品
CN110489131B (zh) 一种灰度用户选取方法及装置
CN113965313A (zh) 基于同态加密的模型训练方法、装置、设备以及存储介质
CN110826695B (zh) 数据处理方法、装置和计算机可读存储介质
CN109783769B (zh) 一种基于用户项目评分的矩阵分解方法和装置
CN111126456A (zh) 神经网络模型的处理方法、装置、设备及存储介质
CN110457155A (zh) 一种样本类别标签的修正方法、装置及电子设备
US20210397962A1 (en) Effective network compression using simulation-guided iterative pruning
CN110610140A (zh) 人脸识别模型的训练方法、装置、设备及可读存储介质
CN112598078B (zh) 混合精度训练方法、装置、电子设备及存储介质
CN113361381B (zh) 人体关键点检测模型训练方法、检测方法及装置
CN111291464A (zh) 一种电力系统动态等值方法及装置
CN113112092A (zh) 一种短期概率密度负荷预测方法、装置、设备和存储介质
CN112766403A (zh) 一种基于信息增益权重的增量聚类方法及装置
CN117435308B (zh) 一种基于并行计算算法的Modelica模型仿真方法及系统
CN116629388B (zh) 差分隐私联邦学习训练方法、装置和计算机可读存储介质
CN110033098A (zh) 在线gbdt模型学习方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant