CN114118370A - 模型训练方法、电子设备和计算机可读存储介质 - Google Patents

模型训练方法、电子设备和计算机可读存储介质 Download PDF

Info

Publication number
CN114118370A
CN114118370A CN202111401804.9A CN202111401804A CN114118370A CN 114118370 A CN114118370 A CN 114118370A CN 202111401804 A CN202111401804 A CN 202111401804A CN 114118370 A CN114118370 A CN 114118370A
Authority
CN
China
Prior art keywords
model
training
loss
vectors
loss function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111401804.9A
Other languages
English (en)
Inventor
浦煜
胡长胜
何武
付贤强
朱海涛
户磊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hefei Dilusense Technology Co Ltd
Original Assignee
Beijing Dilusense Technology Co Ltd
Hefei Dilusense Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Dilusense Technology Co Ltd, Hefei Dilusense Technology Co Ltd filed Critical Beijing Dilusense Technology Co Ltd
Priority to CN202111401804.9A priority Critical patent/CN114118370A/zh
Publication of CN114118370A publication Critical patent/CN114118370A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Biophysics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Evolutionary Biology (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例涉及计算机视觉技术领域,公开了一种模型训练方法、电子设备和计算机可读存储介质,该方法包括:获取待训练的模型和标注有标签的训练样本;将训练样本输入至模型中,获取模型输出的特征向量;根据模型输出的特征向量和标签构建损失函数,对模型进行迭代训练,直到损失函数收敛,得到训练完成的模型;其中,损失函数包括第一损失项和第二损失项,第一损失项为softmax函数,第二损失项用于约束模型分出的各类别的类别中心向量之间的最小距离不小于平均距离,平均距离为各类别的类别中心向量均匀分布时两两之间的距离,从而避免因不同类别的样本数据量不均衡而训练出带有偏向性的图像分类模型,提升模型分类的准确性。

Description

模型训练方法、电子设备和计算机可读存储介质
技术领域
本申请实施例涉及计算机视觉技术领域,特别涉及一种模型训练方法、电子设备和计算机可读存储介质。
背景技术
近些年来,随着大数据的爆发性增长,相关人工智能技术都得到了长足发展,其中以计算机视觉技术取得的成果最为显著,其在工业生产、医疗、安防、金融支付、公共服务等领域具有十分广泛的应用,而图像分类作为计算机视觉的基础任务之一,其主要目的是区分图像中目标对象的类别,是理解图像内容的重要一环。
当前的深度神经网络在一系列图像分类任务上已取得优异的性能,训练图像分类模型时一般使用标配的softmax交叉熵损失函数,虽然softmax交叉熵损失函数计算较为简单,但其在训练过程中不能强制约束不同类别的可分性,这导致使用数据量不平衡的训练样本对图像分类模型进行训练时,训练出的图像分类模型存在偏向性,即模型的分类结果更偏向于数据量较多的训练样本的类别,这降低了图像分类模型的分类的准确性。
发明内容
本申请实施例的目的在于提供一种模型训练方法、电子设备和计算机可读存储介质,可以避免因不同类别的样本数据量不均衡而导致训练出带有偏向性的图像分类模型,从而大幅提升模型分类的准确性。
为解决上述技术问题,本申请的实施例提供了一种模型训练方法,包括以下步骤:获取待训练的模型和标注有标签的训练样本;其中,所述标签用于表征所述训练样本的特征向量将所述训练样本输入至所述模型中,获取所述模型输出的特征向量;根据所述模型输出的特征向量和所述标签构建损失函数,对所述模型进行迭代训练,直到所述损失函数收敛,得到训练完成的模型;其中,所述损失函数包括第一损失项和第二损失项,所述第一损失项为softmax函数,所述第二损失项用于约束所述模型分出的各类别的类别中心向量之间的最小距离不小于平均距离,所述平均距离为所述各类别的类别中心向量均匀分布时两两之间的距离。
本申请的实施例还提供了一种电子设备,包括:至少一个处理器;以及,与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述模型训练方法。
本申请的实施例还提供了一种计算机可读存储介质,存储有计算机程序,所述计算机程序被处理器执行时实现上述模型训练方法。
本申请实施例提供的模型训练方法、电子设备和计算机可读存储介质,获取待训练的模型和标注有标签的训练样本,其中,训练样本上标注的标签用于表征训练样本的特征向量,服务器将训练样本输入至模型中,获取模型输出的特征向量,再根据模型输出的特征向量和训练样本上标注的标签构建损失函数,对模型进行迭代训练,直到损失函数收敛,得到训练完成的模型,其中,构建的损失函数包括第一损失项和第二损失项,第一损失项为softmax函数,第二损失项用于约束模型分出的各类别的类别中心向量之间的最小距离不小于平均距离,平均距离为各类别的类别中心向量均匀分布时两两之间的距离,考虑到只使用标配的softmax函数作为损失函数对模型进行训练,在不同类别的样本的数据量不均衡的情况下,训练出的模型存在偏向性,而本申请不仅设置softmax函数作为第一损失项目,还设置了用于约束模型的各类别的类别中心向量之间的最小距离不小于平均距离的第二损失项,迫使模型分出的各类别尽量均匀分布,即使不同类别的样本数据量不均衡也不会训练出带有偏向性的图像分类模型,从而提升模型分类的准确性,同时,与其他替代softmax损失函数的Sphereface、Cosface、Arcface等损失函数相比,本申请的实施例没有额外的可学习参数,可以直接嵌入使用,提升模型的分类效果。
另外,所述训练样本为若干个,所述对所述模型进行迭代训练具体为基于小批量梯度下降法对所述模型进行迭代训练,所述损失函数还包括第三损失项,所述第三损失项用于约束不同id的训练样本的特征向量之间的夹角不小于π/2,在基于小批量梯度下降法对模型进行迭代训练时,由于一次输入至模型的训练样本数量大于1,即minibatch大于1,mininatch中的样本可能分属不同id,而且可能是不平衡的,因此本申请构建的损失函数中还设置有第三损失项,第三损失项可以约束不同id的训练样本的特征向量之间的夹角不小于π/2,提升类内紧凑性,从而进一步提升模型分类的准确性。
另外,所述损失函数中的第三损失项通过以下步骤构建:获取所述若干个训练样本中各训练样本的id;根据所述各训练样本的id,确定所述模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数;对所述模型输出的不同id的训练样本的特征向量进行L2范数归一化,得到归一化的不同id的训练样本的特征向量;根据所述总次数、所述归一化的不同id的训练样本的特征向量和预设参数,构建所述损失函数中的第三损失项,基于模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数,以及归一化的不同id的训练样本的特征向量构建第三损失项,可以从样本特征的角度对特征分布直接进行约束,从而使得不同类别的训练样本尽量分开,间接提升了类内紧凑性,从而更好地对模型进行训练。
另外,所述损失函数中的第二损失项通过以下步骤构建:获取所述模型分出的类别的个数和所述模型分出的各类别的分类权重向量;对所述各类别的分类权重向量分别进行转置,并分别对转置后的各类别的分类权重向量进行L2范数归一化,得到所述模型的各类别的类别中心向量;根据所述类别的个数和所述各类别的类别中心向量,计算所述平均距离;其中,所述平均距离为所述各类别的类别中心向量均匀分布时两两之间的距离;根据所述平均距离和所述各类别的类别中心向量,构建所述损失函数中的第二损失项,考虑到标配的Softmax损失函数不能保证对模型分出的每个类别都有一个均匀分类区域,这导致样本量较少的类别其在特征空间中的类别中心向量与其它类别的类别中心向量的距离过于接近,从而导致实际分类结果较差,理想情况下,各类别的类别中心向量在特征空间中应该是均匀分布的,因此本实施例根据平均距离和各类别的类别中心向量,构建第二损失项,强迫各类别的类别中心向量在特征空间中呈现均匀分布或近似均匀分布,从而更好地对模型进行训练。
另外,所述将所述训练样本输入至所述模型中,获取所述模型输出的特征向量,包括:对所述训练样本进行预处理,得到预处理后的训练样本;其中,所述预处理包括数据裁剪、数据增强和数据归一化;将所述预处理后的训练样本输入至所述模型中,获取所述模型输出的特征向量,对训练样本进行包括数据裁剪、数据增强和数据归一化的预处理,也可以在一定程度上削弱不同类别的样本数据量不均衡的问题,用预处理后的训练样本进行训练,可以进一步提升模型分类的准确性。
附图说明
一个或多个实施例通过与之对应的附图中的图片进行示例性说明,这些示例性说明并不构成对实施例的限定。
图1是根据本申请的一个实施例的模型训练方法的流程图一;
图2是根据本申请的一个实施例中,构建损失函数中的第三损失项的流程图;
图3是根据本申请的一个实施例中,构建损失函数中的第二损失项的流程图;
图4是根据本申请的另一个实施例的模型训练方法的流程图二;
图5是根据本申请的另一个实施例的电子设备的结构示意图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合附图对本申请的各实施例进行详细的阐述。然而,本领域的普通技术人员可以理解,在本申请各实施例中,为了使读者更好地理解本申请而提出了许多技术细节。但是,即使没有这些技术细节和基于以下各实施例的种种变化和修改,也可以实现本申请所要求保护的技术方案。以下各个实施例的划分是为了描述方便,不应对本申请的具体实现方式构成任何限定,各个实施例在不矛盾的前提下可以相互结合相互引用。
本申请的一个实施例涉及一种模型训练方法,应用于电子设备,其中,电子设备可以为终端或服务器,本实施例以及以下个各个实施例中电子设备以服务器为例进行说明,下面对本实施例的模型训练方法的实现细节进行具体的说明,以下内容仅为方便理解提供的实现细节,并非实施本方案的必须。
本实施例的模型训练方法的具体流程可以如图1所示,包括:
步骤101,获取待训练的模型和标注有标签的训练样本。
具体而言,服务器在进行模型训练时,先获取待训练的模型和标注有标签的训练样本,训练样本上标注的标签可以用于表征该训练样本的特征向量。
在一个例子中,待训练的模型可以为人脸识别模型等图像分类模型,比如VGGNet、GoogleNet、ResNet、DenseNet、ShuffleNet等分类模型。
在一个例子中,服务器可以根据待训练的模型和分类任务的实际应用场景,获取标注有标签的训练样本,比如MegaFace、CASIA-WebFace等人脸数据集,再比如MNIST、CIFAR100、ImageNet等普通分类任务数据集。
步骤102,将训练样本输入至模型中,获取模型输出的特征向量。
在具体实现中,服务器获取到训练的模型和标注有标签的训练样本后,可以将标注有标签的训练样本输入至模型中,模型可以根据自身当前的网络参数输出训练样本的特征表达,一般为训练样本的D维的特征向量,训练样本的D维的特征向量可以理解为训练样本在高维特征空间中的特征点,服务器可以获取模型输出的该训练样本的D维的特征向量。
在一个例子中,服务器基于小批量梯度下降法对模型进行训练,服务器一次向模型中输入N个训练样本,即minibatch为N,训练样本上标注的标签可以用yi,i∈[1,N]表示,yi即第i个训练样本上标注的特征向量,模型输出的特征向量可以用xi,i∈[1,N]表示,xi即第i个训练样本的D维的特征向量。
步骤103,根据模型输出的特征向量和标签构建损失函数,对模型进行迭代训练,直到损失函数收敛,得到训练完成的模型。
具体而言,服务器获取到模型输出的特征向量后,可以根据模型输出的特征向量和训练样本上标注的标签建损失函数,基于构建的损失函数对模型进行迭代训练,直到损失函数收敛,即可得到训练完成的模型,构建的损失函数包括第一损失项和第二损失项,第一损失项为softmax函数,第二损失项用于约束模型分出的各类别的类别中心向量之间的最小距离不小于平均距离,其中,平均距离为各类别的类别中心向量在均匀分布时两两之间的距离。
在一个例子中,服务器基于小批量梯度下降法对模型进行训练,服务器一次向模型中输入N个训练样本,即mini batch为N,服务器构建的损失函数中的第一损失项可以通过以下公式表示:
Figure BDA0003364928020000051
式中,N为训练样本的个数,即minibatch,xn为模型输出的第n个训练样本的特征向量,yn为第n个训练样本上标注的标签,K为模型分出的类别的个数,w为模型的分类权重项,b为模型的网络偏置项,Lsoftmax为服务器构建的损失函数中的第一损失项。
在一个例子中,服务器构建的第二损失项用Lsoftmax表示,服务器构建的第二损失项用Lclass表示,服务器构建的损失函数可以表示为:L=Lsoftmax+Lclass
在具体实现中,服务器以构建的损失函数作为监督,使用小批量梯度下降法对模型进行迭代训练,并实时监测损失函数是否收敛,当服务器确定损失函数收敛后,可以保存模型此时的各种参数,如权重和偏置等,发布训练完成的模型。
本实施例,相较于使用标配的softmax交叉熵损失函数对图像分类模型进行训练的技术方案而言,服务器先获取待训练的模型和标注有标签的训练样本,其中,训练样本上标注的标签用于表征训练样本的特征向量,服务器将训练样本输入至模型中,获取模型输出的特征向量,再根据模型输出的特征向量和训练样本上标注的标签构建损失函数,对模型进行迭代训练,直到损失函数收敛,得到训练完成的模型,其中,构建的损失函数包括第一损失项和第二损失项,第一损失项为softmax函数,第二损失项用于约束模型分出的各类别的类别中心向量之间的最小距离不小于平均距离,平均距离为各类别的类别中心向量均匀分布时两两之间的距离,考虑到只使用标配的softmax函数作为损失函数对模型进行训练,在不同类别的样本的数据量不均衡的情况下,训练出的模型存在偏向性,而本申请不仅设置softmax函数作为第一损失项目,还设置了用于约束模型的各类别的类别中心向量之间的最小距离不小于平均距离的第二损失项,迫使模型分出的各类别尽量均匀分布,即使不同类别的样本数据量不均衡也不会训练出带有偏向性的图像分类模型,从而提升模型分类的准确性,同时,与其他替代softmax损失函数的Sphereface、Cosface、Arcface等损失函数相比,本申请的实施例没有额外的可学习参数,可以直接嵌入使用,提升模型的分类效果。
在一个实施例中,服务器获取的标注有标签的训练样本为若干个,服务器基于小批量梯度下降法对模型进行迭代训练,一次向模型中输入若干个训练样本,输入到模型中的若干个训练样本可以分属于不同的id,服务器根据模型输出的特征向量和标签构建的损失函数还包括第三损失项,第三损失项用于约束不同id的训练样本的特征向量之间的夹角不小于π/2。
在具体实现中,基于小批量梯度下降法对模型进行迭代训练时,由于一次输入至模型的训练样本数量大于1,即mini batch大于1,mini natch中的样本可能分属不同id,而且可能是不平衡的,因此本申请构建的损失函数中还设置有第三损失项,第三损失项可以约束不同id的训练样本的特征向量之间的夹角不小于π/2,提升类内紧凑性,从而进一步提升模型分类的准确性。
在一个例子中,训练样本为人脸图像样本,服务器获取一个批次共32张人脸图像样本,训练样本的id可以为人脸身份,如32张人脸图像样本中有12张人脸图像样本属于人脸甲,即这12张人脸图像样本的id为人脸甲,有10张人脸图像样本属于人脸乙,即这10张人脸图像样本的id为人脸乙,还有10张人脸图像样本属于人脸丙,即这10张人脸图像样本的id为人脸丙。
在一个例子中,服务器构建的第三损失项用Lfeature表示,服务器构建的损失函数可以表示为:L=Lsoftmax+Lclass+Lfeature
在一个实施例中,服务器获取的标注有标签的训练样本为若干个,服务器基于小批量梯度下降法对模型进行迭代训练,一次向模型中输入若干个训练样本,输入到模型中的若干个训练样本可以分属于不同的id,服务器构建损失函数中的第三损失项可以通过如图2所示的各步骤实现,具体包括:
步骤201,获取若干个训练样本中各训练样本的id。
在具体实现中,服务器基于小批量梯度下降法对模型进行迭代训练,一次向模型中输入N个训练样本,这N个训练样本可以分属于不同的id,服务器可以获取这N个训练样本各自分属的id。
在一个例子中,训练样本为人脸图像样本,服务器获取一个批次共32张人脸图像样本,训练样本的id可以为人脸身份,如32张人脸图像样本中有12张人脸图像样本属于人脸甲,即这12张人脸图像样本的id为人脸甲,有10张人脸图像样本属于人脸乙,即这10张人脸图像样本的id为人脸乙,还有10张人脸图像样本属于人脸丙,即这10张人脸图像样本的id为人脸丙。
步骤202,根据各训练样本的id,确定模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数。
在一个例子中,训练样本为人脸图像样本,服务器获取一个批次共8张人脸图像样本,训练样本的id可以为人脸身份,如8张人脸图像样本中有4张人脸图像样本属于人脸甲,即这4张人脸图像样本的id为人脸甲,分别记为甲1、甲2、甲3和甲4,有3张人脸图像样本属于人脸乙,即这3张人脸图像样本的id为人脸乙,分别记为乙1、乙2和乙3,还有1张人脸图像样本属于人脸丙,即这1张人脸图像样本的id为人脸丙,记为丙1,那么模型输出的甲1的特征向量,与模型输出的乙1、乙2、乙3和丙1的特征向量就是不同id的训练样本的特征向量,那么甲1的特征向量分别于乙1、乙2、乙3和丙1的特征向量计算相似度,模型输出的甲2的特征向量,与模型输出的乙1、乙2、乙3和丙1的特征向量是不同id的训练样本的特征向量,那么甲2的特征向量分别于乙1、乙2、乙3和丙1的特征向量计算相似度,以此类推,服务器确定模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数|B|=4+4+4+4+1+1+1=19次。
步骤203,对模型输出的不同id的训练样本的特征向量进行L2范数归一化,得到归一化的不同id的训练样本的特征向量。
在具体实现中,服务器将N个训练样本输入至模型中后,可以对模型输出的不同id的训练样本的特征向量进行L2范数归一化,得到归一化的不同id的训练样本的特征向量,其中,服务器也可以根据实际情况选用其他归一化方法对模型输出的不同id的训练样本的特征向量进行归一化,得到归一化的不同id的训练样本的特征向量。
步骤204,根据模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数、归一化的不同id的训练样本的特征向量和预设参数,构建损失函数中的第三损失项。
具体而言,服务器在得到归一化的不同id的训练样本的特征向量后,可以根据模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数、归一化的不同id的训练样本的特征向量和预设参数,构建损失函数中的第三损失项,其中,预设参数可以由本领域的技术人员根据实际需要进行设置,本申请的实施例对此不做具体限定。
在一个例子中,服务器可以通过以下公式,根据模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数、归一化的不同id的训练样本的特征向量和预设参数,构建损失函数中的第三损失项:
Figure BDA0003364928020000081
式中,|B|为模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数,
Figure BDA0003364928020000082
为归一化的第a种id的训练样本的特征向量,
Figure BDA0003364928020000083
为归一化的第b种id的训练样本的特征向量,λ为预设参数,Lfeature为第三损失项。
本实施例,所述损失函数中的第三损失项通过以下步骤构建:获取所述若干个训练样本中各训练样本的id;根据所述各训练样本的id,确定所述模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数;对所述模型输出的不同id的训练样本的特征向量进行L2范数归一化,得到归一化的不同id的训练样本的特征向量;根据所述总次数、所述归一化的不同id的训练样本的特征向量和预设参数,构建所述损失函数中的第三损失项,基于模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数,以及归一化的不同id的训练样本的特征向量构建第三损失项,可以从样本特征的角度对特征分布直接进行约束,从而使得不同类别的训练样本尽量分开,间接提升了类内紧凑性,从而更好地对模型进行训练。
在一个实施例中,服务器构建损失函数中的第二损失项可以通过如图3所示的各步骤实现,具体包括:
步骤301,获取模型分出的类别的个数和模型分出的各类别的分类权重向量。
步骤302,对各类别的分类权重向量分别进行转置,并分别对转置后的各类别的分类权重向量进行L2范数归一化,得到模型的各类别的类别中心向量。
在具体实现中,服务器构建损失函数中的第二损失项时,可以先获取模型分出的类别的个数K和模型分出的各类别的分类权重向量wk,k∈[1,K],分类权重向量wk实际上是一个D×K维度的向量,服务器对各类别的分类权重向量wk分别进行转置,得到转置后的各类别的分类权重向量
Figure BDA0003364928020000084
并对转置后的各类别的分类权重向量
Figure BDA0003364928020000085
进行L2范数归一化,得到模型的各类别的类别中心向量
Figure BDA0003364928020000086
类别中心向量
Figure BDA0003364928020000087
实际上是一个K×D维度的向量。
步骤303,根据模型分出的类别的个数和各类别的类别中心向量,计算平均距离。
在具体实现中,服务器得到各类别的类别中心向量后,可以根据模型分出的类别的个数K和各类别的类别中心向量
Figure BDA0003364928020000088
计算平均距离μ,平均距离μ为各类别的类别中心向量
Figure BDA0003364928020000089
在均匀分布情况下两两之间的距离。
在一个例子中,服务器计算出的平均距离μ可以通过以下公式表示:
Figure BDA0003364928020000091
式中,μ为平均距离,K为模型分出的类别的个数,
Figure BDA0003364928020000092
为第j类的类别中心向量,
Figure BDA0003364928020000093
为第k类的类别中心向量。
步骤304,根据平均距离和各类别的类别中心向量,构建损失函数中的第二损失项。
在具体实现中,服务器计算出平均距离后,可以根据平均距离μ和各类别的类别中心向量
Figure BDA0003364928020000094
构建损失函数中的第二损失项。
在一个例子中,服务器可以通过以下公式,根据平均距离μ和各类别的类别中心向量
Figure BDA0003364928020000095
构建损失函数中的第二损失项:
Figure BDA0003364928020000096
Figure BDA0003364928020000097
式中,μ为平均距离,K为类别的个数,(K2-K)/2表示不同类别的类别中心向量之间两两计算相似度的次数,
Figure BDA0003364928020000098
为第j类的类别中心向量,
Figure BDA0003364928020000099
为第k类的类别中心向量,Lclass为第二损失项。
本实施例,所述损失函数中的第二损失项通过以下步骤构建:获取所述模型分出的类别的个数和所述模型分出的各类别的分类权重向量;对所述各类别的分类权重向量分别进行转置,并分别对转置后的各类别的分类权重向量进行L2范数归一化,得到所述模型的各类别的类别中心向量;根据所述类别的个数和所述各类别的类别中心向量,计算所述平均距离;其中,所述平均距离为所述各类别的类别中心向量均匀分布时两两之间的距离;根据所述平均距离和所述各类别的类别中心向量,构建所述损失函数中的第二损失项,考虑到标配的Softmax损失函数不能保证对模型分出的每个类别都有一个均匀分类区域,这导致样本量较少的类别其在特征空间中的类别中心向量与其它类别的类别中心向量的距离过于接近,从而导致实际分类结果较差,理想情况下,各类别的类别中心向量在特征空间中应该是均匀分布的,因此本实施例根据平均距离和各类别的类别中心向量,构建第二损失项,强迫各类别的类别中心向量在特征空间中呈现均匀分布或近似均匀分布,从而更好地对模型进行训练。
本申请的另一个实施例涉及一种模型训练方法,下面对本实施例的模型训练方法的实现细节进行具体的说明,以下内容仅为方便理解提供的实现细节,并非实施本方案的必须,本实施例的模型训练方法的具体流程可以如图4所示,包括:
步骤401,获取待训练的模型和标注有标签的训练样本。
其中,步骤401与步骤101大致相同,此处不再赘述。
步骤402,对训练样本进行预处理,得到预处理后的训练样本,预处理包括数据裁剪、数据增强和数据归一化。
步骤403,将预处理后的训练样本输入至模型中,获取模型输出的特征向量。
在具体实现中,服务器获取到待训练的模型和标注有标签的训练样本后,可以对训练样本进行包括数据裁剪、数据增强和数据归一化在内的预处理,数据裁剪的目的在于使各训练样本之间可以对齐,数据增强可以包括随机噪声、随机对比度等,数据归一化的目的在于简化计算过程。
步骤404,根据模型输出的特征向量和标签构建损失函数,对模型进行迭代训练,直到损失函数收敛,得到训练完成的模型。
其中,步骤404与步骤103大致相同,此处不再赘述。
本实施例,所述将所述训练样本输入至所述模型中,获取所述模型输出的特征向量,包括:对所述训练样本进行预处理,得到预处理后的训练样本;其中,所述预处理包括数据裁剪、数据增强和数据归一化;将所述预处理后的训练样本输入至所述模型中,获取所述模型输出的特征向量,对训练样本进行包括数据裁剪、数据增强和数据归一化的预处理,也可以在一定程度上削弱不同类别的样本数据量不均衡的问题,用预处理后的训练样本进行训练,可以进一步提升模型分类的准确性。
上面各种方法的步骤划分,只是为了描述清楚,实现时可以合并为一个步骤或者对某些步骤进行拆分,分解为多个步骤,只要包括相同的逻辑关系,都在本专利的保护范围内;对算法中或者流程中添加无关紧要的修改或者引入无关紧要的设计,但不改变其算法和流程的核心设计都在该专利的保护范围内。
本申请另一个实施例涉及一种电子设备,如图5所示,包括:至少一个处理器501;以及,与所述至少一个处理器501通信连接的存储器502;其中,所述存储器502存储有可被所述至少一个处理器501执行的指令,所述指令被所述至少一个处理器501执行,以使所述至少一个处理器501能够执行上述各实施例中的模型训练方法。
其中,存储器和处理器采用总线方式连接,总线可以包括任意数量的互联的总线和桥,总线将一个或多个处理器和存储器的各种电路连接在一起。总线还可以将诸如外围设备、稳压器和功率管理电路等之类的各种其他电路连接在一起,这些都是本领域所公知的,因此,本文不再对其进行进一步描述。总线接口在总线和收发机之间提供接口。收发机可以是一个元件,也可以是多个元件,比如多个接收器和发送器,提供用于在传输介质上与各种其他装置通信的单元。经处理器处理的数据通过天线在无线介质上进行传输,进一步,天线还接收数据并将数据传送给处理器。
处理器负责管理总线和通常的处理,还可以提供各种功能,包括定时,外围接口,电压调节、电源管理以及其他控制功能。而存储器可以被用于存储处理器在执行操作时所使用的数据。
本申请另一个实施例涉及一种计算机可读存储介质,存储有计算机程序。计算机程序被处理器执行时实现上述方法实施例。
即,本领域技术人员可以理解,实现上述实施例方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,该程序存储在一个存储介质中,包括若干指令用以使得一个设备(可以是单片机,芯片等)或处理器(processor)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-OnlyMemory,简称:ROM)、随机存取存储器(Random Access Memory,简称:RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
本领域的普通技术人员可以理解,上述各实施例是实现本申请的具体实施例,而在实际应用中,可以在形式上和细节上对其作各种改变,而不偏离本申请的精神和范围。

Claims (10)

1.一种模型训练方法,其特征在于,包括:
获取待训练的模型和标注有标签的训练样本;其中,所述标签用于表征所述训练样本的特征向量;
将所述训练样本输入至所述模型中,获取所述模型输出的特征向量;
根据所述模型输出的特征向量和所述标签构建损失函数,对所述模型进行迭代训练,直到所述损失函数收敛,得到训练完成的模型;其中,所述损失函数包括第一损失项和第二损失项,所述第一损失项为softmax函数,所述第二损失项用于约束所述模型分出的各类别的类别中心向量之间的最小距离不小于平均距离,所述平均距离为所述各类别的类别中心向量均匀分布时两两之间的距离。
2.根据权利要求1所述的模型训练方法,其特征在于,所述训练样本为若干个,所述对所述模型进行迭代训练具体为基于小批量梯度下降法对所述模型进行迭代训练,所述损失函数还包括第三损失项,所述第三损失项用于约束不同id的训练样本的特征向量之间的夹角不小于π/2。
3.根据权利要求2所述的模型训练方法,其特征在于,所述损失函数中的第三损失项通过以下步骤构建:
获取所述若干个训练样本中各训练样本的id;
根据所述各训练样本的id,确定所述模型输出的不同id的训练样本的特征向量之间两两计算相似度的总次数;
对所述模型输出的不同id的训练样本的特征向量进行L2范数归一化,得到归一化的不同id的训练样本的特征向量;
根据所述总次数、所述归一化的不同id的训练样本的特征向量和预设参数,构建所述损失函数中的第三损失项。
4.根据权利要求3所述的模型训练方法,其特征在于,通过以下公式,根据所述总次数、所述归一化的不同id的训练样本的特征向量和预设参数,构建所述损失函数中的第三损失项:
Figure FDA0003364928010000011
其中,|B|为所述总次数,
Figure FDA0003364928010000012
为归一化的第a种id的训练样本的特征向量,
Figure FDA0003364928010000013
为归一化的第b种id的训练样本的特征向量,λ为所述预设参数,Lfeature为所述第三损失项。
5.根据权利要求1至4种任一项所述的模型训练方法,其特征在于,所述损失函数中的第二损失项通过以下步骤构建:
获取所述模型分出的类别的个数和所述模型分出的各类别的分类权重向量;
对所述各类别的分类权重向量分别进行转置,并分别对转置后的各类别的分类权重向量进行L2范数归一化,得到所述模型的各类别的类别中心向量;
根据所述类别的个数和所述各类别的类别中心向量,计算所述平均距离;其中,所述平均距离为所述各类别的类别中心向量均匀分布时两两之间的距离;
根据所述平均距离和所述各类别的类别中心向量,构建所述损失函数中的第二损失项。
6.根据权利要求5所述的模型训练方法,其特征在于,通过以下公式,根据所述平均距离和所述各类别的类别中心向量,构建所述损失函数中的第二损失项:
Figure FDA0003364928010000021
Figure FDA0003364928010000022
其中,μ为所述平均距离,K为所述类别的个数,
Figure FDA0003364928010000023
为第j类的类别中心向量,
Figure FDA0003364928010000024
为第k类的类别中心向量,Lclass为所述第二损失项。
7.根据权利要求1至4中任一项所述的模型训练方法,其特征在于,所述训练样本为若干个,通过以下公式表示所述损失函数中的第一损失项:
Figure FDA0003364928010000025
其中,N为所述训练样本的个数,xn为所述模型输出的第n个训练样本的特征向量,yn为所述第n个训练样本的标签,K为所述模型分出的类别的个数,w为所述模型的分类权重项,b为所述模型的网络偏置项,Lsoftmax为所述第一损失项。
8.根据权利要求1至4中任一项所述的模型训练方法,其特征在于,所述将所述训练样本输入至所述模型中,获取所述模型输出的特征向量,包括:
对所述训练样本进行预处理,得到预处理后的训练样本;其中,所述预处理包括数据裁剪、数据增强和数据归一化;
将所述预处理后的训练样本输入至所述模型中,获取所述模型输出的特征向量。
9.一种电子设备,其特征在于,包括:
至少一个处理器;以及,
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行如权利要求1至8中任一所述的模型训练方法。
10.一种计算机可读存储介质,存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至8中任一项所述的模型训练方法。
CN202111401804.9A 2021-11-19 2021-11-19 模型训练方法、电子设备和计算机可读存储介质 Pending CN114118370A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111401804.9A CN114118370A (zh) 2021-11-19 2021-11-19 模型训练方法、电子设备和计算机可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111401804.9A CN114118370A (zh) 2021-11-19 2021-11-19 模型训练方法、电子设备和计算机可读存储介质

Publications (1)

Publication Number Publication Date
CN114118370A true CN114118370A (zh) 2022-03-01

Family

ID=80371622

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111401804.9A Pending CN114118370A (zh) 2021-11-19 2021-11-19 模型训练方法、电子设备和计算机可读存储介质

Country Status (1)

Country Link
CN (1) CN114118370A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115577258A (zh) * 2022-09-08 2023-01-06 中国电信股份有限公司 振动信号识别模型训练方法、电机故障检测方法及装置
CN116912920A (zh) * 2023-09-12 2023-10-20 深圳须弥云图空间科技有限公司 表情识别方法及装置

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112052789A (zh) * 2020-09-03 2020-12-08 腾讯科技(深圳)有限公司 人脸识别方法、装置、电子设备及存储介质
CN112949780A (zh) * 2020-04-21 2021-06-11 佳都科技集团股份有限公司 特征模型训练方法、装置、设备及存储介质

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112949780A (zh) * 2020-04-21 2021-06-11 佳都科技集团股份有限公司 特征模型训练方法、装置、设备及存储介质
CN112052789A (zh) * 2020-09-03 2020-12-08 腾讯科技(深圳)有限公司 人脸识别方法、装置、电子设备及存储介质

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
MUNAWAR HAYAT 等,: "Gaussian Affinity for Max-Margin Class Imbalanced Learning", 《2019 IEEE/CVF INTERNATIONAL CONFERENCE ON COMPUTER VISION (ICCV)》 *
WEIYANG LIU 等,: "SphereFace: Deep Hypersphere Embedding for Face Recognition", 《2017 IEEE CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR)》 *
明悦,: "《多源视觉信息感知与识别》", 31 August 2020 *

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115577258A (zh) * 2022-09-08 2023-01-06 中国电信股份有限公司 振动信号识别模型训练方法、电机故障检测方法及装置
CN116912920A (zh) * 2023-09-12 2023-10-20 深圳须弥云图空间科技有限公司 表情识别方法及装置
CN116912920B (zh) * 2023-09-12 2024-01-05 深圳须弥云图空间科技有限公司 表情识别方法及装置

Similar Documents

Publication Publication Date Title
CN111537945B (zh) 基于联邦学习的智能电表故障诊断方法及设备
WO2021143396A1 (zh) 利用文本分类模型进行分类预测的方法及装置
CN111523621A (zh) 图像识别方法、装置、计算机设备和存储介质
CN110377587B (zh) 基于机器学习的迁移数据确定方法、装置、设备及介质
CN110210625B (zh) 基于迁移学习的建模方法、装置、计算机设备和存储介质
CN114118370A (zh) 模型训练方法、电子设备和计算机可读存储介质
CN111507470A (zh) 一种异常账户的识别方法及装置
CN113449704B (zh) 人脸识别模型训练方法、装置、电子设备及存储介质
CN113822315A (zh) 属性图的处理方法、装置、电子设备及可读存储介质
Masood et al. Differential evolution based advised SVM for histopathalogical image analysis for skin cancer detection
CN113674087A (zh) 企业信用等级评定方法、装置、电子设备和介质
CN113434699A (zh) Bert模型的预训练方法、计算机装置和存储介质
CN112668482A (zh) 人脸识别训练方法、装置、计算机设备及存储介质
US20180137409A1 (en) Method of constructing an artifical intelligence super deep layer learning model, device, mobile terminal, and software program of the same
CN112541530B (zh) 针对聚类模型的数据预处理方法及装置
Igual et al. Supervised learning
CN116739787B (zh) 基于人工智能的交易推荐方法及系统
CN115795355B (zh) 一种分类模型训练方法、装置及设备
CN114005015B (zh) 图像识别模型的训练方法、电子设备和存储介质
Villegas-Cortez et al. Interest points reduction using evolutionary algorithms and CBIR for face recognition
CN115619541A (zh) 一种风险预测系统和方法
CN113011893B (zh) 数据处理方法、装置、计算机设备及存储介质
CN114937166A (zh) 图像分类模型构建方法、图像分类方法及装置、电子设备
CN113128615A (zh) 基于pca的bp神经网络对信息安全的检测系统、方法、应用
CN111400413A (zh) 一种确定知识库中知识点类目的方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
TA01 Transfer of patent application right

Effective date of registration: 20220520

Address after: 230091 room 611-217, R & D center building, China (Hefei) international intelligent voice Industrial Park, 3333 Xiyou Road, high tech Zone, Hefei, Anhui Province

Applicant after: Hefei lushenshi Technology Co.,Ltd.

Address before: 100083 room 3032, North B, bungalow, building 2, A5 Xueyuan Road, Haidian District, Beijing

Applicant before: BEIJING DILUSENSE TECHNOLOGY CO.,LTD.

Applicant before: Hefei lushenshi Technology Co.,Ltd.

TA01 Transfer of patent application right
RJ01 Rejection of invention patent application after publication

Application publication date: 20220301

RJ01 Rejection of invention patent application after publication