CN113065593A - 模型训练方法、装置、计算机设备和存储介质 - Google Patents
模型训练方法、装置、计算机设备和存储介质 Download PDFInfo
- Publication number
- CN113065593A CN113065593A CN202110355199.XA CN202110355199A CN113065593A CN 113065593 A CN113065593 A CN 113065593A CN 202110355199 A CN202110355199 A CN 202110355199A CN 113065593 A CN113065593 A CN 113065593A
- Authority
- CN
- China
- Prior art keywords
- loss value
- image
- value
- distance value
- image processing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及一种模型训练方法、装置、计算机设备和存储介质。所述方法包括:获取三元组图像样本;将所述三元组图像样本输入至图像处理模型中,获得参考损失值;当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。采用本方法能够提高训练得到的模型的性能。
Description
技术领域
本申请涉及机器学习技术领域,特别是涉及一种模型训练方法、装置、计算机设备和存储介质。
背景技术
大多数现代机器学习算法(例如深度神经网络)都需要对数百种算法进行训练。而在某些实际问题中,往往无法提供大量示例。少样本学习方法通过使用少量示例来训练计算模型,试图解决这些问题。
这些算法只需要几个每个类的训练示例,无需大量反复训练即可将模型推广到不熟悉的类别。为了解决少数问题,已经采用了诸如度量学习算法之类的方法。度量学习算法采用一些距离度量来发现图像之间的相似性。
深度度量算法的最重要元素是损失函数。最受欢迎的损失函数之一是三元组损失函数。然而传统的通过三元组损失函数所训练得到的图像处理模型的性能较差。
发明内容
基于此,有必要针对上述技术问题,提供一种模型训练方法、装置、计算机设备和存储介质。
一种模型训练方法,所述方法包括:
获取三元组图像样本;
将所述三元组图像样本输入至图像处理模型中,获得参考损失值;
当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;
基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。
一种模型训练装置,所述装置包括:
图像样本获取模块,用于获取三元组图像样本;
目标损失值获得模块,用于将所述三元组图像样本输入至图像处理模型中,获得参考损失值;
所述目标损失值获得模块,还用于当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;
参数调整模块,用于基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。
一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
获取三元组图像样本;
将所述三元组图像样本输入至图像处理模型中,获得参考损失值;
当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;
基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。
一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
获取三元组图像样本;
将所述三元组图像样本输入至图像处理模型中,获得参考损失值;
当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;
基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。
上述模型训练方法、装置、计算机设备和存储介质,获取三元组图像样本,将三元组图像样本输入至图像处理模型中,获得参考损失值,当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值,能够通过更合理的损失值减少图像处理模型的训练时长,也能够使得模型训练不容易崩溃,能够使得训练完成的图像处理模型具有更好的性能。
附图说明
图1为一个实施例中模型训练方法的流程示意图;
图2为一个实施例中训练前的图像处理模型的ROC曲线;
图3为一个实施例中训练完成的图像处理模型的ROC曲线;
图4为一个实施例中未迭代的图像处理模型的映射空间类间距的结果示意图;
图5为一个实施例中迭代了100次的视频图像处理模型映射空间类间距的结果示意图;
图6为一个实施例中模型训练装置的结构框图;
图7为一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
本申请提供的模型训练方法,可以应用于计算机设备中。其中,计算机设备可以为终端或者服务器。其中,终端可以但不限于是各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备。服务器可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
在一个实施例中,如图1所示,提供了一种模型训练方法,以该方法应用于计算机设备为例进行说明,包括以下步骤:
步骤102,获取三元组图像样本。
其中,三元组图像样本中包括3个图像样本,其中包括锚点图像样本(Anchor)、正图像样本(Positive)和负图像样本(Negative)。随机选取的锚点图像样本,与锚点图像样本属于相同类型的正图像样本,与锚点图像样本属于不同类的负图像样本。
具体地,计算机设备读取数据集、预处理数据集中的图像,并分割出训练集和测试集。三元组图像样本是训练集中的图像样本。
步骤104,将三元组图像样本输入至图像处理模型中,获得参考损失值。
其中,图像处理模型可以是指图像识别模型,也可以是用于协同图像分割等的模型。图像处理模型具体可以是预训练图像处理模型。
具体地,计算机设备将三元组图像样本输入至图像处理模型中,通过图像处模型获得该三元组图像样本对应的参考损失值。参考损失值可以是通过欧几里得距离或余弦距离实现。
其中,dis表示距离,τz表示第z个三元组图像样本,a表示锚点图像样本,p表示正图像样本,n表示负图像样本,m是一个常数。
步骤106,当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值。
其中,预设损失值条件用于判断参考损失值的大小是否满足预设条件。
具体地,当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值。
例如,预设损失值条件为参考损失值在预设范围内,加入的对应调整项为加上惩罚项。预设损失值条件为参考损失值大于预设距离阈值,加入的对应调整项为减去奖赏项。
步骤108,基于目标损失值对图像处理模型中的参数进行调整,直至图像处理模型满足收敛条件,获得训练完成的图像处理模型。
其中,收敛条件具体可以是训练次数达到预设次数、评估参数达到预设评估值等不限于此。
具体地,计算机设备基于目标损失值对图像处理模型中的参数进行调整,直至图像处理模型满足收敛条件时,获得训练完成的图像处理模型。
上述模型训练方法,获取三元组图像样本,将三元组图像样本输入至图像处理模型中,获得参考损失值,当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值,能够通过更合理的损失值减少图像处理模型的训练时长,也能够使得模型训练不容易崩溃,通过少量样本就能够将图像处理模型训练成功,还能够使得训练完成的图像处理模型具有更好的性能。
在一个实施例中,三元组图像样本包括锚点图像样本、正图像样本和负图像样本;参考损失值是基于第一距离值与第二距离值之差计算得到的。
当参考损失值满足预设损失值条件时,在参考损失值中加入与预设损失值条件对应的调整项,获得目标损失值,包括:
当第一距离值与第二距离值之差在预设范围内时,在参考损失值的基础上加上惩罚项,获得目标损失值;第一距离值是锚点图像样本与正图像样本之间的距离值;第二距离值是锚点图像样本与负图像样本之间的距离值。
其中,距离值可以是欧几里得距离值,也可以是余弦距离值等不限于此。第一距离值是锚点图像样本与正图像样本之间的距离值。预设范围是一个设定的数值范围。第一距离值与第二距离值之差在预设范围内可以是(0,x),x是一个设定的阈值。并且x可以设置为较小的数值,例如0.01、0.1等不限于此。
具体地,当第一距离值与第二距离值之差在预设范围内时,说明该损失值较小,是一个较好的结果,因此在损失值的基础上减去奖赏项,获得目标损失值。例如,参考损失值为loss,目标损失值为Tloss,奖赏项为reward,第一距离值为dis1,第二距离值为dis2,预设范围内是(0,x)那么当x>dis1-dis2>0时,Tloss=loss-reward。
本实施例中,第一距离值是锚点图像样本与正图像样本之间的距离值,第二距离值是锚点图像样本与负图像样本之间的距离值,当第一距离值与第二距离值之差在预设范围内时,说明该损失值较小,该次训练效果较好,因此在参考损失值的基础上减去奖赏项,获得目标损失值,从而使得训练出的图像处理模型性能更佳。
在一个实施例中,奖赏项的获取方式,包括:根据第一距离值与第二距离值之和,与第一参数的乘积,获得奖赏项。
其中,第一参数可以根据需求设置。
具体地,当第一距离值与第二距离值之差在预设范围内时,称之为最好情况。最好情况满足:
这个distance比0大但是又不会大很多。我们假设:
其中,ε=km,0<k<1
于是有:
因此,在参考损失值的基础上减去奖赏项,为:
上式中的loss(τBz)即为目标损失值,loss为参考损失值,reward即为奖赏项。dis表示距离,τz表示第z个三元组图像样本,a表示锚点图像样本,p表示正图像样本,n表示负图像样本,m是一个常数。上式中的a/2可视为第一参数。
本实施例中,根据第一距离值与第二距离值之差,与第一参数的乘积,获得奖赏项,能够使得训练出的模型性能更好。
在一个实施例中,三元组图像样本包括锚点图像样本、正图像样本和负图像样本;参考损失值是基于第一距离值与第二距离值之差计算得到的。
当参考损失值满足预设损失值条件时,在参考损失值中加入与预设损失值条件对应的调整项,获得目标损失值,包括:
当第一距离值与第二距离值之差大于预设距离阈值时,在参考损失值的基础上加上惩罚项,获得目标损失值;第一距离值是锚点图像样本与正图像样本之间的距离值;第二距离值是锚点图像样本与负图像样本之间的距离值。
具体地,当第一距离值与第二距离值只差大于预设距离值时,说明参考损失值较大,需要增加惩罚项以获得目标损失值。
例如,参考损失值为loss,目标损失值为Tloss,惩罚项为penalty,第一距离值为dis1,第二距离值为dis2,预设距离阈值为y,那么当dis1-dis2>y时,Tloss=loss+penalty。
本实施例中,第一距离值是锚点图像样本与正图像样本之间的距离值,第二距离值是锚点图像样本与负图像样本之间的距离值,当第一距离值与第二距离值之差大于预设距离阈值时,说明该损失值较大,该次训练效果较差,因此在参考损失值的基础上加上惩罚项,获得目标损失值,从而使得训练出的图像处理模型性能更佳。
在一个实施例中,惩罚项的获取方式,包括:根据第一距离值与第二距离值之和,与第二参数的乘积,获得惩罚项。
其中,第二参数可以根据需要调整。
具体地,计算机设备根据第一距离值与第二距离值之和,与第二参数的乘积,获得惩罚项。
例如,当第一距离值与第二距离值之差大于预设距离阈值m时,有
那么目标损失值则有
上式中的loss(τBz)即为目标损失值,loss为参考损失值,reward即为奖赏项。dis表示距离,τz表示第z个三元组图像样本,a表示锚点图像样本,p表示正图像样本,n表示负图像样本,m是一个常数。上式中的a/2可视为第二参数。
本实施例中,根据第一距离值与第二距离值之和,与第二参数的乘积,获得奖赏项,能够使得训练出的模型性能更好。
在一个实施例中,获取三元组图像样本,包括:
从第一类的图像集中获取参考锚点图像样本以及对应的参考正图像样本,从第二类的图像集中获取参考负图像样本;第一类和第二类不同;
将参考锚点图像样本、参考正图像样本和参考负图像样本作为候选三元组图像样本;
从候选三元组图像样本中,确定第一距离值与第二距离值之差大于预设距离阈值的三元组图像样本;
当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值,包括:
在参考损失值的基础上加上惩罚项,获得目标损失值。
具体地,计算机设备从第一类的图像集中获取参考锚点图像样本以及对应的参考正图像样本。即参考锚点图像样本和参考正图像样本属于同一类的图像样本。计算机设备从第二类的图像集中获取对应的参考负图像样本。例如第一类的图像集为A类的图像集,第二类的图像集为B类的图像集。将参考锚点图像样本、参考正图像样本和参考负图像样本作为候选三元组图像样本。
计算机设备计算各三元组图像样本中每个图像样本的第一距离值和第二距离值,并从候选三元组图像样本中,确定第一距离值与第二距离值之差大于预设距离阈值的三元组图像样本。那么该三元组图像样本的损失值则满足第一距离值与第二距离值之差大于预设距离阈值的条件,因此在参考损失值的基础上加上惩罚项,获得目标损失值。
本实施例中,从第一类的图像集中获取参考锚点图像样本以及对应的参考正图像样本,从第二类的图像集中获取参考负图像样本,第一类和第二类不同,将参考锚点图像样本、参考正图像样本和参考负图像样本作为候选三元组图像样本,从候选三元组图像样本中,确定第一距离值与第二距离值之差大于预设距离阈值的三元组图像样本,能够在随机选取策略的基础上,增加一次筛选,选出候选三元组图像样本中表现较差的样本,将选出的样本作为模型训练数据,能够加快模型训练,并且通过少量样本即可将模型训练好。
在一个实施例中,该模型训练方法还包括:
获取三元组图像测试样本;
将三元组图像测试样本输入至训练完成的图像处理模型,获得图像处理结果;
基于图像处理结果生成对训练完成的图像处理模型的评估数据;
将评估数据以图表的方式呈现。
其中,当图像处理模型是图像识别模型时,图像处理结果即为图像识别结果。当图像处理模型是图像分割模型时,图像处理结果即为图像分割结果。
具体地,三元组图像测试样本用于检测训练完成的图像处理模型的模型性能。计算机设备获取三元组图像测试样本,将三元组图像测试样本输入至训练完成的图像处理模型,获得图像处理结果。
计算机设备基于图像处理结果生成对训练完成的图像处理模型的评估数据。计算机设备将评估数据以图标的方式呈现在终端。
本实施例中,获取三元组图像测试样本,将三元组图像测试样本输入至训练完成的图像处理模型,获得图像处理结果,基于图像处理结果生成对训练完成的图像处理模型的评估数据,并将评估数据以图表方式呈现,能够更加直观地展示模型训练效果。
在一个实施例中,评估数据可以包括ROC(receiver operating characteristiccurve,接受者操作特征曲线)、AUC(Area Under Curve)值被定义为ROC曲线下与坐标轴围成的面积、召回率(recall)、灵敏度(sensitive)等。
其中ROC曲线的ROC曲线的横坐标是false positive rate(FPR),纵坐标是truepositive rate(TPR)。
TPR:在所有实际为阳性的样本中,被正确地判断为阳性之比率。TPR=TP/(TP+FN)
FPR:在所有实际为阴性的样本中,被错误地判断为阳性之比率。FPR=FP/(FP+TN)
阳性(P,positive)阴性(N,Negative)
真阳性(TP,true positive)正确的肯定。又称:命中(hit)。
真阴性(TN,true negative)正确的否定。又称:正确拒绝(correct rejection)。
伪阳性(FP,false positive)错误的肯定,又称:假警报(false alarm),第一型错误。
伪阴性(FN,false negative)错误的否定,又称:未命中(miss),第二型错误。
以一个表格表示:
表1
AUC值的定义
AUC值为ROC曲线所覆盖的区域面积,显然,AUC越大,分类器分类效果越好。
AUC=1,是完美分类器,采用这个预测模型时,不管设定什么阈值都能得出完美预测。绝大多数预测的场合,不存在完美分类器。
0.5<AUC<1,优于随机猜测。这个分类器(模型)妥善设定阈值的话,能有预测价值。
AUC=0.5,跟随机猜测一样(例:丢铜板),模型没有预测价值。
AUC<0.5,比随机猜测还差;但只要总是反预测而行,就优于随机猜测。
召回率是覆盖面的度量,度量有多个正例被分为正例,recall=TP/(TP+FN)=TP/P=sensitive。
如图2所示,为一个实施例中训练前的图像处理模型的ROC曲线。图中ROC曲线所围成的面积即AUC值=0.707,sensitivity(recall)=1.6%,@FPR=1e-03,threshold=0.030945874750614166。其中该threshold用于表示分类的正误。
如图3所示,为一个实施例中训练完成的图像处理模型的ROC曲线。图中ROC曲线所围成的面积即AUC值=0.929,sensitivity(recall)=7.9%,@FPR=1e-03,threshold=0.05653351545333862。由此可知,训练完成的图像处理模型的分类效果比未训练的图像处理模型更好。
如图4所示,为一个实施例中未迭代的图像处理模型的映射空间类间距的结果示意图。图4中横坐标表示类型(classes),图中选取了0~9类。纵坐标表示映射空间类间距(distance)。理论上来说,映射空间类间距越大,表示图像识别效果或者说图像分类效果越好。
如图5所示,为一个实施例中迭代了100次的视频图像处理模型映射空间类间距的结果示意图。图5中横坐标表示类型(classes),图中选取了0~9类。纵坐标表示映射空间类间距(distance)。理论上来说,映射空间类间距越大,表示图像识别效果越好。相比于图4,很显然图5中的图像识别效果比图4中的更好。
用于图像识别时,与传统的三元组损失函数效果对比:(表2)
方式 | AUC | Recall(sensitivity) |
标准三元组损失函数 | 0.997 | 77.7% |
加入惩罚项 | 0.998 | 81.5% |
加入惩罚项和奖赏项 | 0.998 | 87.2% |
用于图像协同分割时,与传统的三元组损失函数效果对比:(表3)
方式 | Precision(精确度) | Jaccard index |
余弦距离损失值 | 68.2 | 0.57 |
标准三元组损失函数 | 69.06 | 0.61 |
加入惩罚项 | 72.8 | 0.65 |
加入惩罚项和奖赏项 | 73.2 | 0.66 |
Jaccard index,又称为Jaccard相似系数(Jaccard similarity coefficient)用于比较有限样本集之间的相似性与差异性。Jaccard系数值越大,样本相似度越高。
表2和表3中传统三元组损失函数中不含奖赏项和惩罚项。表中的加入惩罚项是指当第一距离值与第二距离值之差大于预设距离阈值时,在参考损失值的基础上加上惩罚项,获得目标损失值。表中的加入奖赏项是指当第一距离值与第二距离值之差在预设范围内时,在参考损失值的基础上减去奖赏项,获得目标损失值。加入惩罚项和加入奖赏项即为当第一距离值与第二距离值之差大于预设距离阈值时,在参考损失值的基础上加上惩罚项;且当第一距离值与第二距离值之差在预设范围内时,在参考损失值的基础上减去奖赏项。
显然,加入本申请各实施例中的奖赏项和惩罚项能够使得模型的各种性能大大提高。
在一个实施例中,本申请各个实施例中的主要目标是解决标准三元组损失的限制,并提出一个有条件的损失函数。我们通过对三元组最坏的情况进行惩罚和对三元组最好的情况给予奖励来改进标准的三元组损失。我们将所提出的条件三元组损失函数用于图像识别和图像协同分割任务上。
在随机采样的三元组中,直接放入神经网络模型训练并不能体现各个三元组的差异性(标准三元组损失不考虑这个)。比如最好的用于网络训练的数据是刚好使得以上损失大于0(小于0的不能用于网络训练,没有意义),这样的训练数据往往提供给模型训练的梯度小,使得模型训练慢。另外对于最坏的情况是锚点正样本对的距离比锚点负样本对的距离大很多,这样的数据往往提供很大的训练梯度差,能加快模型训练,但由于提供的这个梯度差可能过大,容易使得整个模型陷入局部最优甚至崩溃,所以对这样最坏的三元组数据我们对齐做出惩罚。
假设原始的每个三元组产生的损失为标准损失loss,那么对最坏的情况我们增加惩罚项penalty,即损失变为loss+penalty;对最好的三元素数据则变为loss-reward,其他情况loss不变。通过上述方式可以通过少量样本即可将模型训练成功。
在一个实施例中,该模型训练方法还包括:当参考损失值不满足预设损失值条件时,将参考损失值作为目标损失值。由于有些三元组样本比较寻常,不算太差也不算太好,对于这类样本不做特殊处理,同样能够达到表2和表3的效果。
在一个实施例中,软件主要由以下几个模块组成:数据集导入和预处理模块,三元组提取模块,神经网络模块,性能评估模块,可视化模块。核心模块为神经网路模块,主要部分是用于神经网络训练的三元组损失函数的设计优化,其中神经网络结构本身为可代替模块。目前训练用数据集为手写数字识别MNIST数据集。
这五个模块之间的逻辑为首先数据集模块读取并预处理数据集,已备后边使用,三元组模块则使用提取出的数据集中的训练集提取直接用于网络训练的三元组,神经网络模块则是定义神经网络结构,定义损失函数,性能评估模块则是使用提取好的数据集,训练神经网络,得到最优的网络模型,可视化模块则是将训练好的模型应用于测试数据集中的结果以图形图表的形式展现。主要涉及的函数有:
DrawPics(),drawTriplets(),draw_roc(),draw_interdist(),DrawTestImage()等。
其中,定义用于图像特征向量提取的神经网络,由于整个代码核心不是在网络结构上,故定义一个简单的网络结构作为基础。网络输入大小为28x28,MINIST为单通道图像,故读取的图片张量为(28,28,1),网络结构如下图:第一个卷积层为128个7x7卷积核,最大池化层。第二个卷积层为128个3x3卷积核,最大池化层,256个3x3卷积核,此时输出的特征图为256层的2x2特征,经过一个flatten层,变成2x2x256=1024的特征向量,接着两层全连接层1024->4096->10,最后对每个得到的10维的向量做L2正则化,使其变为单位向量。
应该理解的是,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
在一个实施例中,如图6所示,提供了一种模型训练装置,包括:图像样本获取模块602、目标损失值获得模块604、和参数调整模块606,其中:
图像样本获取模块602,用于获取三元组图像样本;
目标损失值获得模块604,用于将三元组图像样本输入至图像处理模型中,获得参考损失值;
目标损失值获得模块604,还用于当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值;
参数调整模块606,用于基于目标损失值对图像处理模型中的参数进行调整,直至图像处理模型满足收敛条件,获得训练完成的图像处理模型。
上述模型训练装置,获取三元组图像样本,将三元组图像样本输入至图像处理模型中,获得参考损失值,当参考损失值满足预设损失值条件时,在参考损失值的基础上加入与预设损失值条件对应的调整项,获得目标损失值,能够通过更合理的损失值减少图像处理模型的训练时长,也能够使得模型训练不容易崩溃,能够使得训练完成的图像处理模型具有更好的性能。
在一个实施例中,三元组图像样本包括锚点图像样本、正图像样本和负图像样本;参考损失值是基于第一距离值与第二距离值之差计算得到的。目标损失值获得模块604,用于当第一距离值与第二距离值之差在预设范围内时,在参考损失值的基础上减去奖赏项,获得目标损失值;第一距离值是锚点图像样本与正图像样本之间的距离值;第二距离值是锚点图像样本与负图像样本之间的距离值。
本实施例中,第一距离值是锚点图像样本与正图像样本之间的距离值,第二距离值是锚点图像样本与负图像样本之间的距离值,当第一距离值与第二距离值之差在预设范围内时,说明该损失值较小,该次训练效果较好,因此在参考损失值的基础上减去奖赏项,获得目标损失值,从而使得训练出的图像处理模型性能更佳。
在一个实施例中,目标损失值获得模块604还用于根据第一距离值与第二距离值之和,与第一参数的乘积,获得奖赏项。
本实施例中,根据第一距离值与第二距离值之差,与第一参数的乘积,获得奖赏项,能够使得训练出的模型性能更好。
在一个实施例中,三元组图像样本包括锚点图像样本、正图像样本和负图像样本;参考损失值是基于第一距离值与第二距离值之差计算得到的。目标损失值获得模块604还用于当第一距离值与第二距离值之差大于预设距离阈值时,在参考损失值的基础上加上惩罚项,获得目标损失值;第一距离值是锚点图像样本与正图像样本之间的距离值;第二距离值是锚点图像样本与负图像样本之间的距离值。
本实施例中,第一距离值是锚点图像样本与正图像样本之间的距离值,第二距离值是锚点图像样本与负图像样本之间的距离值,当第一距离值与第二距离值之差大于预设距离阈值时,说明该损失值较大,该次训练效果较差,因此在参考损失值的基础上加上惩罚项,获得目标损失值,从而使得训练出的图像处理模型性能更佳。
在一个实施例中,目标损失值获得模块604还用于根据第一距离值与第二距离值之和,与第二参数的乘积,获得惩罚项。
本实施例中,根据第一距离值与第二距离值之和,与第二参数的乘积,获得奖赏项,能够使得训练出的模型性能更好。
在一个实施例中,图像样本获取模块602用于从第一类的图像集中获取参考锚点图像样本以及对应的参考正图像样本,从第二类的图像集中获取参考负图像样本;第一类和第二类不同;将参考锚点图像样本、参考正图像样本和参考负图像样本作为候选三元组图像样本;从候选三元组图像样本中,确定第一距离值与第二距离值之差大于预设距离阈值的三元组图像样本。目标损失值获得模块604用于在参考损失值的基础上加上惩罚项,获得目标损失值。
本实施例中,从第一类的图像集中获取参考锚点图像样本以及对应的参考正图像样本,从第二类的图像集中获取参考负图像样本,第一类和第二类不同,将参考锚点图像样本、参考正图像样本和参考负图像样本作为候选三元组图像样本,从候选三元组图像样本中,确定第一距离值与第二距离值之差大于预设距离阈值的三元组图像样本,能够在随机选取策略的基础上,增加一次筛选,选出候选三元组图像样本中表现较差的样本,将选出的样本作为模型训练数据,能够加快模型训练。
在一个实施例中,模型训练装置还包括性能评估模块和可视化模块。图像样本获取模块602还用于获取三元组图像测试样本。性能评估模块用于将三元组图像测试样本输入至训练完成的图像处理模型,获得图像处理结果;基于图像处理结果生成对训练完成的图像处理模型的评估数据。可视化模块用于将评估数据以图表的方式呈现。
本实施例中,获取三元组图像测试样本,将三元组图像测试样本输入至训练完成的图像处理模型,获得图像处理结果,基于图像处理结果生成对训练完成的图像处理模型的评估数据,并将评估数据以图表方式呈现,能够更加直观地展示模型训练效果。
关于模型训练装置的具体限定可以参见上文中对于模型训练方法的限定,在此不再赘述。上述模型训练装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是终端,其内部结构图可以如图7所示。该计算机设备包括通过系统总线连接的处理器、存储器、通信接口、显示屏和输入装置。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和计算机程序。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的通信接口用于与外部的终端进行有线或无线方式的通信,无线方式可通过WIFI、运营商网络、NFC(近场通信)或其他技术实现。该计算机程序被处理器执行时以实现一种模型训练方法。该计算机设备的显示屏可以是液晶显示屏或者电子墨水显示屏,该计算机设备的输入装置可以是显示屏上覆盖的触摸层,也可以是计算机设备外壳上设置的按键、轨迹球或触控板,还可以是外接的键盘、触控板或鼠标等。
本领域技术人员可以理解,图7中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,还提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机可读存储介质,存储有计算机程序,该计算机程序被处理器执行时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述各方法实施例中的步骤。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-Only Memory,ROM)、磁带、软盘、闪存或光存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic Random Access Memory,DRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。
Claims (10)
1.一种模型训练方法,其特征在于,所述方法包括:
获取三元组图像样本;
将所述三元组图像样本输入至图像处理模型中,获得参考损失值;
当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;
基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。
2.根据权利要求1所述的方法,其特征在于,所述三元组图像样本包括锚点图像样本、正图像样本和负图像样本;所述参考损失值是基于第一距离值与第二距离值之差计算得到的;
所述当所述参考损失值满足预设损失值条件时,在所述参考损失值中加入与所述预设损失值条件对应的调整项,获得目标损失值,包括:
当第一距离值与第二距离值之差在预设范围内时,在所述参考损失值的基础上减去奖赏项,获得目标损失值;所述第一距离值是所述锚点图像样本与所述正图像样本之间的距离值;所述第二距离值是所述锚点图像样本与所述负图像样本之间的距离值。
3.根据权利要求2所述的方法,其特征在于,所述奖赏项的获取方式,包括:
根据所述第一距离值与所述第二距离值之差,与第一参数的乘积,获得奖赏项。
4.根据权利要求1所述的方法,其特征在于,所述三元组图像样本包括锚点图像样本、正图像样本和负图像样本;所述参考损失值是基于第一距离值与第二距离值之差计算得到的;
所述当所述参考损失值满足预设损失值条件时,在所述参考损失值中加入与所述预设损失值条件对应的调整项,获得目标损失值,包括:
当第一距离值与第二距离值之差大于预设距离阈值时,在所述参考损失值的基础上加上惩罚项,获得目标损失值;所述第一距离值是所述锚点图像样本与所述正图像样本之间的距离值;所述第二距离值是所述锚点图像样本与所述负图像样本之间的距离值。
5.根据权利要求4所述的方法,其特征在于,所述惩罚项的获取方式,包括:
根据所述第一距离值与所述第二距离值之和,与第二参数的乘积,获得惩罚项。
6.根据权利要求1所述的方法,其特征在于,所述获取三元组图像样本,包括:
从第一类的图像集中获取参考锚点图像样本以及对应的参考正图像样本,从第二类的图像集中获取参考负图像样本;所述第一类和所述第二类不同;
将所述参考锚点图像样本、参考正图像样本和参考负图像样本作为候选三元组图像样本;
从候选三元组图像样本中,确定第一距离值与第二距离值之差大于预设距离阈值的三元组图像样本;
所述当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值,包括:
在所述参考损失值的基础上加上惩罚项,获得目标损失值。
7.根据权利要求1至6任一项所述的方法,其特征在于,所述方法还包括:
获取三元组图像测试样本;
将所述三元组图像测试样本输入至所述训练完成的图像处理模型,获得图像处理结果;
基于所述图像处理结果生成对所述训练完成的图像处理模型的评估数据;
将所述评估数据以图表的方式呈现。
8.一种模型训练装置,其特征在于,所述装置包括:
图像样本获取模块,用于获取三元组图像样本;
目标损失值获得模块,用于将所述三元组图像样本输入至图像处理模型中,获得参考损失值;
所述目标损失值获得模块,还用于当所述参考损失值满足预设损失值条件时,在所述参考损失值的基础上加入与所述预设损失值条件对应的调整项,获得目标损失值;
参数调整模块,用于基于所述目标损失值对所述图像处理模型中的参数进行调整,直至所述图像处理模型满足收敛条件,获得训练完成的图像处理模型。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至7中任一项所述的方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110355199.XA CN113065593A (zh) | 2021-04-01 | 2021-04-01 | 模型训练方法、装置、计算机设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110355199.XA CN113065593A (zh) | 2021-04-01 | 2021-04-01 | 模型训练方法、装置、计算机设备和存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113065593A true CN113065593A (zh) | 2021-07-02 |
Family
ID=76565355
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110355199.XA Pending CN113065593A (zh) | 2021-04-01 | 2021-04-01 | 模型训练方法、装置、计算机设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113065593A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113326832A (zh) * | 2021-08-04 | 2021-08-31 | 北京的卢深视科技有限公司 | 模型训练、图像处理方法、电子设备及存储介质 |
CN114049634A (zh) * | 2022-01-12 | 2022-02-15 | 深圳思谋信息科技有限公司 | 一种图像识别方法、装置、计算机设备和存储介质 |
CN117746381A (zh) * | 2023-12-12 | 2024-03-22 | 北京迁移科技有限公司 | 位姿估计模型配置方法及位姿估计方法 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180114055A1 (en) * | 2016-10-25 | 2018-04-26 | VMAXX. Inc. | Point to Set Similarity Comparison and Deep Feature Learning for Visual Recognition |
CN110532880A (zh) * | 2019-07-29 | 2019-12-03 | 深圳大学 | 样本筛选及表情识别方法、神经网络、设备及存储介质 |
-
2021
- 2021-04-01 CN CN202110355199.XA patent/CN113065593A/zh active Pending
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180114055A1 (en) * | 2016-10-25 | 2018-04-26 | VMAXX. Inc. | Point to Set Similarity Comparison and Deep Feature Learning for Visual Recognition |
CN110532880A (zh) * | 2019-07-29 | 2019-12-03 | 深圳大学 | 样本筛选及表情识别方法、神经网络、设备及存储介质 |
Non-Patent Citations (1)
Title |
---|
DAMING SHI ET AL.: "A conditional Triplet loss for few-shot learning and its application to image co-segmentation", 《NEURAL NETWORKS》, pages 54 - 62 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113326832A (zh) * | 2021-08-04 | 2021-08-31 | 北京的卢深视科技有限公司 | 模型训练、图像处理方法、电子设备及存储介质 |
CN113326832B (zh) * | 2021-08-04 | 2021-12-17 | 北京的卢深视科技有限公司 | 模型训练、图像处理方法、电子设备及存储介质 |
CN114049634A (zh) * | 2022-01-12 | 2022-02-15 | 深圳思谋信息科技有限公司 | 一种图像识别方法、装置、计算机设备和存储介质 |
CN117746381A (zh) * | 2023-12-12 | 2024-03-22 | 北京迁移科技有限公司 | 位姿估计模型配置方法及位姿估计方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US10726244B2 (en) | Method and apparatus detecting a target | |
CN113065593A (zh) | 模型训练方法、装置、计算机设备和存储介质 | |
CN109086711B (zh) | 人脸特征分析方法、装置、计算机设备和存储介质 | |
US10832032B2 (en) | Facial recognition method, facial recognition system, and non-transitory recording medium | |
CN110610143B (zh) | 多任务联合训练的人群计数网络方法、系统、介质及终端 | |
CN110245714B (zh) | 图像识别方法、装置及电子设备 | |
US9129152B2 (en) | Exemplar-based feature weighting | |
CN111401219B (zh) | 一种手掌关键点检测方法和装置 | |
CN109840524A (zh) | 文字的类型识别方法、装置、设备及存储介质 | |
CN112560710B (zh) | 一种用于构建指静脉识别系统的方法及指静脉识别系统 | |
CN113179421B (zh) | 视频封面选择方法、装置、计算机设备和存储介质 | |
CN111274999A (zh) | 数据处理、图像处理方法、装置及电子设备 | |
CN111292377A (zh) | 目标检测方法、装置、计算机设备和存储介质 | |
CN110796250A (zh) | 应用于卷积神经网络的卷积处理方法、系统及相关组件 | |
CN115690672A (zh) | 异常图像识别方法、装置、计算机设备和存储介质 | |
CN112749737A (zh) | 图像分类方法及装置、电子设备、存储介质 | |
CN111862040A (zh) | 人像图片质量评价方法、装置、设备及存储介质 | |
CN114519401A (zh) | 一种图像分类方法及装置、电子设备、存储介质 | |
CN115731442A (zh) | 图像处理方法、装置、计算机设备和存储介质 | |
CN115758271A (zh) | 数据处理方法、装置、计算机设备和存储介质 | |
CN115631370A (zh) | 一种基于卷积神经网络的mri序列类别的识别方法及装置 | |
CN114677578A (zh) | 确定训练样本数据的方法和装置 | |
CN113688655B (zh) | 干扰信号的识别方法、装置、计算机设备和存储介质 | |
CN114519729A (zh) | 图像配准质量评估模型训练方法、装置和计算机设备 | |
CN113269176B (zh) | 图像处理模型训练、图像处理方法、装置和计算机设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |