CN113011532A - 分类模型训练方法、装置、计算设备及存储介质 - Google Patents

分类模型训练方法、装置、计算设备及存储介质 Download PDF

Info

Publication number
CN113011532A
CN113011532A CN202110481964.2A CN202110481964A CN113011532A CN 113011532 A CN113011532 A CN 113011532A CN 202110481964 A CN202110481964 A CN 202110481964A CN 113011532 A CN113011532 A CN 113011532A
Authority
CN
China
Prior art keywords
prediction probability
probability distribution
classification model
class
prediction
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110481964.2A
Other languages
English (en)
Inventor
吴天博
王健宗
黄章成
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202110481964.2A priority Critical patent/CN113011532A/zh
Publication of CN113011532A publication Critical patent/CN113011532A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请实施例提供一种分类模型训练方法、装置、计算设备及存储介质,其中,方法包括将训练样本输入分类模型,得到训练样本属于每个类别的预测概率分布;计算目标损失和惩罚项,其中惩罚项用于指示预测概率分布中负类的离散程度。将目标损失与惩罚项之和记为总损失,根据总损失更新分类模型参数。在传统模型训练中,由于损失函数的局限性,忽略了模型对负类预测的准确度。本申请将负类分布的离散程度作为惩罚项引入模型损失构建新型损失函数,提升原有损失函数的性能,提高了模型的预测能力。

Description

分类模型训练方法、装置、计算设备及存储介质
技术领域
本申请涉及深度学习领域,具体涉及一种分类模型训练方法、装置、计算设备及存储介质。
背景技术
多分类任务的目的是对一个输入数据赋予合适的类别标签。在多分类任务中,数据的类别标签仅有一个,通过分类模型预测类别标签概率,将概率最大的标签作为数据的类别。
在分类模型的训练中使用损失函数来进行参数更新,多分类任务中常用的损失函数为交叉熵损失。但交叉熵损失函数只关心对于正类标签预测概率的准确性,在实际使用中导致训练出的模型预测准确度低,无法取得很好的效果。
发明内容
本申请提供一种分类模型训练方法、装置、计算设备及存储介质,将模型对负类标签预测的准确度纳入模型损失的计算,提高分类模型的预测能力。
第一方面,本申请提供一种分类模型训练方法,包括:将训练样本输入分类模型,得到训练样本属于每个类别的预测概率分布,预测概率分布为分类模型预测的训练样本属于每个类别的预测概率;根据预测概率分布与训练样本的实际标签分布,计算目标损失,实际标签分布为训练样本属于每个类别的实际概率,目标损失用于指示分类模型预测概率分布与实际标签分布之间的误差;根据预测概率分布中负类的预测概率分布,计算惩罚项,惩罚项用于指示预测概率分布中负类的离散程度;将目标损失与惩罚项之和记为总损失,根据总损失更新分类模型参数,得到训练好的分类模型。
分类模型为自监督预训练模型,需要使用训练样本对模型进行训练,根据模型预测结果计算模型损失,向减少模型损失的方向进行反向传播更新分类模型参数,从而提高分类模型的预测能力。
在一种可能的实现方式中,根据预测概率分布中负类的预测概率分布,计算惩罚项,包括:获取预测概率分布中负类的预测概率分布;根据负类的预测概率分布计算负类的预测概率分布的方差;根据分类标签数、正类标签数以及负类的预测概率分布的方差,确定惩罚项。
原有的损失函数只关注到了模型对正类的预测是否准确而忽略了负类分布的情况,出于减少模型误判的考虑,结果中正负类概率的差别更明显的预测效果更好。将负类预测概率分布的方差作为惩罚项,使得负类结果分布更均匀的结果模型损失更小。
在一种可能的实现方式中,根据预测概率分布中负类的预测概率分布,计算惩罚项,包括:获取预测概率分布中负类的预测概率分布;根据负类的预测概率分布计算负类的预测概率分布的极差,极差为负类的预测概率分布中最大预测概率与最小预测概率之差;根据负类的预测概率分布的极差,确定惩罚项。
原有的损失函数只关注到了模型对正类的预测是否准确而忽略了负类分布的情况,出于减少模型误判的考虑,结果中正负类概率的差别更明显的预测效果更好。将负类预测概率分布的极差作为惩罚项,使得负类结果分布更均匀的结果模型损失更小。
在一种可能的实现方式中,根据预测概率分布与训练样本的实际标签分布,计算目标损失包括:根据预测概率分布中第i个类别的预测概率与第i个类别的实际概率,计算交叉熵损失作为目标损失,其中,i的取值为1到N,N为分类标签数。
在一种可能的实现方式中,根据总损失更新分类模型参数,得到训练好的分类模型包括:基于总损失进行反向传播,得到分类模型中多个网络层的梯度;基于多个网络层的梯度,对多个网络层进行参数更新。
神经网络训练过程需要经过前向传播、反向传播和参数的更新不断迭代等过程,直至收敛(可以根据损失函数的值是否不再下降或者趋于稳定来判断是否收敛),从而获得训练好的神经网络。前向传播是从神经网络层的第一层向最后一层传播的过程,反向传播是从神经网络层的最后一层向第一层传播的过程。
在一种可能的实现方式中,根据总损失更新分类模型参数,得到训练好的分类模型之后还包括:得到训练好的分类模型;将待预测样本输入到训练好的分类模型,得到待预测样本属于每个类别的预测概率分布;将预测概率高于预设阈值的类别作为待预测样本的标签。
第二方面,本申请提供一种分类模型训练装置,包括训练单元和处理单元;训练单元将训练样本输入分类模型,得到训练样本属于每个类别的预测概率分布,预测概率分布为分类模型预测的训练样本属于每个类别的预测概率;处理单元根据预测概率分布与训练样本的实际标签分布,计算目标损失,实际标签分布为训练样本属于每个类别的实际概率,目标损失用于指示分类模型预测概率分布与实际标签分布之间的误差;根据预测概率分布中负类的预测概率分布,计算惩罚项,惩罚项用于指示预测概率分布中负类的离散程度;将目标损失与惩罚项之和记为总损失,根据总损失更新分类模型参数,得到训练好的分类模型。
第三方面,本申请提供一种神经网络处理器,所述神经网络处理器包括用于实现如第一方面或第一方面任意可能的实现方式中所述的方法。
第四方面,本申请提供一种计算设备,包括处理器和存储器;所述存储器用于存储指令,所述处理器用于执行所述指令,当所述处理器执行所述指令时,所述计算设备执行如第一方面或第一方面任意可能的实现方式中所述的方法。
第五方面,本申请一种计算机存储介质,计算机存储介质存储有计算机程序,计算机程序被处理器执行时实现如第一方面或第一方面任意可能的实现方式中所述的方法。
本方案使用负类预测概率分布的离散程度作为惩罚项计算新型损失函数,在后续的梯度下降与反向传播中,向减少损失的方向更新分类模型参数,使训练好的模型在预测分类结果时,负类预测结果分布更加均匀,正负类预测结果差别更加明显,减少模型误判,提高模型的预测准确性。
附图说明
图1为本申请实施例提供的一种分类模型训练方法流程图;
图2为本申请实施例提供的一种实体多分类模型示意图;
图3为本申请实施例提供的一种分类模型训练装置的结构示意图;
图4为本申请实施例提供的一种神经网络模型的结构框图;
图5是本申请实施例提供的一种神经网络处理器的结构框图;
图6为本申请实施例提供的一种服务器的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在此本申请说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本申请。如在本申请说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
首先介绍本申请的应用场景。分类任务有多分类和多标签分类两种,多分类任务中一条数据只有一个标签。如判定某个人的性别,只能归类为"男性"、"女性"其中一个,判断一个文本的情感只能归类为"正面"、"中面"或者"负面"其中一个。多标签分类任务中一条数据可能有多个标签。例如,一篇新闻可能同时归类为"娱乐"和"运动",也可能只属于"娱乐"或者其它类别。分类任务的对象可以是实体、图像和语音等。
在分类任务的模型训练中,使用损失函数进行模型的参数更新。损失函数用来估量模型的预测值与真实值的差异,若损失函数很小,表明模型预测结果与数据真实分布很接近,则模型性能良好;若损失函数很大,表明模型预测结果与数据真实分布差别较大,则模型性能不佳。根据损失函数值进行反向传播,更新模型参数。
实践中一般使用交叉熵损失函数作为损失函数。在分类任务中,训练样本有确定的类别,训练样本所属的正确类别为正类,其余类别为负类。而交叉熵损失函数只关注到了正类的预测损失而忽略了负类分布的情况。
以实体多分类任务为例,实体多分类通常是指把文本形式的实体根据一系列的特征划分到指定类别。近些年来,随着知识图谱技术的发展,实体多分类技术被大量用于知识图谱的构建中,旨在对图谱中的实体进行分类。一般短文本中的实体多分类任务是将实体按不同类型进行划分,包括地点,人物,作品,食物、交通工具,虚拟事物等24种不同类型。每个实体的类别仅有一个,通过分类模型预测实体属于每一个类别的概率,将概率最大的类别作为实体的分类类型。
以训练样本做6分类为例,实际标签分布为[0,1,0,0,0,0],若模型得到的两种预测结果分别为:
P1:[0.05,0.50,0.30,0.10,0.05,0.10];
P2:[0.10,0.50,0.10,0.10,0.10,0.10]。
由于标签分布中的负类标签值为0,在计算时将所有负类的损失计算都归于0,P1和P2两种预测结果的正类预测结果相同,因此两种预测结果的交叉熵损失是一样的。但出于减少模型误判的考虑,预测结果中正负类概率的差别更明显的预测效果更好,因此希望模型训练后得到P2的预测结果。因此,如何在模型训练的过程中,使模型的预测结果中负类的预测结果分布更加均匀,是一个亟待解决的问题。
为了解决上述问题,本方案提出一种基于新型损失函数的分类模型训练方法,在交叉熵损失函数的基础上,通过引入负类概率分布的离散程度作为惩罚项,提升损失函数的性能,提高模型的预测能力。
下面介绍本申请的具体方案,参见图1。
S101、将训练样本输入分类模型,输出训练样本属于每个类别的预测概率分布。
其中,预测概率分布为分类模型预测的训练样本属于每个类别的预测概率。分类模型为需要预训练的神经网络模型,将训练样本输入分类模型,分类模型能够将训练样本编码为特征向量,对特征向量进行提取特征数据,根据特征数据对训练样本进行分类,输出训练样本属于的每个类别的预测概率分布,得到对训练样本所属类别的预测结果。
预测概率分布为分类模型预测训练样本属于每个类别的预测概率,如对训练样本做三分类,判断训练样本属于三种类别中的一种或多种。训练样本经过分类模型输出属于三种类别的预测概率分布为[0.5,0.3,0.1],表示分类模型预测训练样本属于第一种类别的概率为0.5,预测训练样本属于第二种类别的概率为0.3,预测训练样本属于第三种类别的概率为0.1。
其中,训练样本可以为图像、文本或语音中的任意一种。例如,图像分类的任务可以是对于一张给定的图像,预测图像中的动物属于哪个类别标签,如判断图像属于“猫”、“狗”或“兔子”中的一种或几种类别。文本分类又称为实体分类,实体分类的任务是将对于给定的文本实体,预测其属于哪个分类标签,如判断实体“苹果”属于“水果”、“人物”或“品牌”中的一种或几种类别。语音分类的任务,是对于一段给定的语音数据,预测其属于哪个分类标签,如判断语音属于“钻孔机声”、“车鸣笛声”或“狗吠声”中的一种或几种类别。
示例性的,图2为本申请提供的一种实体多分类模型,包括BERT模型、全连接层和Softmax。训练样本为标注文本,使用训练样本对实体多分类模型进行预训练。在训练样本前添加特殊标志符[CLS],将[CLS]与训练样本转换为词向量输入BERT模型。经过BERT模型处理后输出每个字符所对应的输出向量,每个字符对应的输出向量都包含了上下文语境以及字符本身的信息,特殊标志符[CLS]对应的输出向量作为整个训练样本的语义表示。
选取特殊标志符[CLS]对应的输出向量与训练样本中实体对应的输出向量进行均值拼接后输入全连接层,对全连接层的输出使用Softmax做分类,得到实体属于每一个类别的预测概率分布。
S102、根据预测概率分布与训练样本的实际标签分布,计算目标损失。
其中,实际标签分布为训练样本属于每个类别的实际概率,目标损失用于指示分类模型预测概率分布与实际标签分布之间的误差,损失值越小表示预测的结果越好。。
训练样本的实际标签分布为训练样本属于每个类别的实际概率,如对训练样本做三分类,判断训练样本属于三种类别中的一种或多种。训练样本的实际分类属于第一种类别,则实际标签分布为[1,0,0],表示训练样本实际属于第一种类别的概率为1,训练样本实际属于第二种类别的概率为0,训练样本实际属于第三种类别的概率为0。
在一种可能的实现方式中,目标损失可以为交叉熵损失。获取训练样本的预测概率分布与实际标签分布后,能够根据公式(1)计算交叉熵损失LCE
Figure BDA0003048813950000051
其中,预测概率分布为Q,实际标签分布为P,分类标签数为N,P(i)表示训练样本属于第i个类别的预测概率,Q(i)表示训练样本属于第i个类别的实际概率。
S103、根据预测概率分布中负类的预测概率分布,计算惩罚项。
其中,惩罚项用于指示预测概率分布中负类的离散程度。
在分类任务中,训练样本有确定的类别,训练样本所属的实际类别为正类,用该训练样本进行训练时,除该训练样本所属的实际类别之外的其余类别为负类。如对训练样本做三分类,训练样本的实际类别属于第一种类别,则实际标签分布为[1,0,0],即实际标签分布中概率为1的类别为正类,其余类别为负类;训练样本经过分类模型得到预测概率分布[0.5,0.3,0.1],则正类对应的预测概率为0.5,负类的预测概率分布为[0.3,0.1]。
在分类任务的预测中,为了减少模型的误判,应当使预测概率分布中正负类预测概率的差值更大,本申请实施例引入负类的预测概率分布的离散程度作为惩罚项。其中,惩罚项可以为方差,标准差、极差或其他能够描述概率分布离散程度的指标。极差为负类的预测概率分布中最大预测概率与最小预测概率之差。
在一种可能的实现方式中,当惩罚项是负类的预测概率分布的方差时,对于预测概率分布Q,分类标签数为N,正类标签数为k(1≤k≤N)的训练任务,假设正类的位置为idx,则移除预测概率分布Q中idx位置的预测值之后得到分布Q′,Q′为负类的预测概率分布。根据公式(2)和公式(3)计算Q′的均值μ和方差σ2
Figure BDA0003048813950000052
Figure BDA0003048813950000053
则根据公式(4)惩罚项Lvar为:
Lvar=α(N-k)σ2 (4)
其中,α为超参数,用来调节惩罚项的比例,本申请实施例对α的值不做具体限定,优选的,α取值范围为[0.8,1.3]。
在一种可能的实现方式中,当惩罚项是负类的预测概率分布的方差时,对于预测概率分布Q,分类标签数为N,假设正类的位置为idx,则移除概率分布Q中idx位置的预测值之后得到分布Q′,Q′为负类的预测概率分布,Q′(max)为负类预测概率分布中的最大概率,Q′(min)为负类预测概率分布中的最小概率,计算极差R:
R=Q′(max)-Q′(min) (5)
则惩罚项Lvar为:
Lvar=αR (6)
其中,α为超参数,用来调节惩罚项的比例,本申请实施例对α的值不做具体限定,优选的,α取值范围为[0.8,1.3]。
S104、将目标损失与惩罚项之和记为总损失,根据总损失更新分类模型参数。本申请实施例中,总损失包括目标损失和惩罚项,根据上述方法得到目标损失和惩罚项,能够计算得到总损失,根据总损失进行反向传播,更新分类模型的权重参数,使模型达到收敛状态。分类模型的参数更新向着减小总损失的方向进行,即减小目标损失和惩罚项,减小目标损失使得正类预测概率更准确;引入负类预测概率分布的离散程度作为惩罚项,减小惩罚项使得负类的预测概率分布更均匀。
目标损失为LCE,惩罚项为Lvar,将目标损失与惩罚项之和记为总损失,根据公式(7)计算总损失Ltotal为:
Ltotal=LCE+Lvar (7)
分类模型的训练过程需要经过前向传播、反向传播和参数的更新不断迭代等过程,直至收敛(可以根据总损失函数的值是否不再下降或者趋于稳定来判断是否收敛),从而获得训练好的分类模型。前向传播是从分类模型的第一层向最后一层传播的过程,是分类模型对训练样本进行预测的过程,反向传播是从分类模型的最后一层向第一层传播的过程。在反向传播的过程中,主要计算分类模型中各个网络层的梯度值,并根据各层中的梯度值进行参数(比如权重、偏置等)的更新,通过反向传播来更新网络层中参数,使分类模型的输出值不断接近目标值,这样,经过多次的迭代训练,最终使分类模型收敛。
以实体6分类任务为例,若实际标签分布为[0,1,0,0,0,0],即第二种类别为实体的正类,其余类别为负类,分类模型得到的两种预测概率分布分别为:
P1:[0.05,0.50,0.30,0.10,0.05,0.10];
P2:[0.10,0.50,0.10,0.10,0.10,0.10]。
预测概率分布P1表示分类模型预测实体属于第一种类别的概率为0.05,预测实体属于第二种类别的概率为0.50,预测实体属于第三种类别的概率为0.30,预测实体属于第四种类别的概率为0.10,预测实体属于第五种类别的概率为0.05,预测实体属于第六种类别的概率为0.10。P1存在两个较大的概率值分别为正类预测概率值为0.50与负类预测概率值0.30,正负类的预测概率值差别不大。
预测概率分布P2表示分类模型预测实体属于第二种类别的概率为0.50,预测实体属于剩下五种类别的概率均为0.10。P2只有一个较大的概率值为正类预测概率值0.5,正负类预测概率值差别明显,负类的预测概率分布更均匀。
在正类0.5预测概率值相同的情况下,对于负类的预测概率分布越均匀,正负类的界限就更加明显,模型更容易正确判断实体所属于的类别。
在一种可能的实现方式中,以目标损失和方差惩罚项计算总损失时。预测概率分布P1去掉正类的预测概率值之后的负类预测概率分布为[0.05,0.3,0.1,0.05,0.1],根据公式(1)计算得到目标损失为0.3,根据公式(2)计算得到负类的预测值的均值为0.12,根据公式(3)计算得到方差为0.0086,在α取1.0的情况下,根据公式(4)计算得到惩罚为0.043,根据公式(7)计算得到总损失为0.343。预测概率分布P2去掉正类预测概率值之后的分布为[0.10,0.10,0.10,0.10,0.10],根据公式(1)计算得到目标损失为0.3,根据公式(2)计算得到负类的预测值的均值为0.10,根据公式(3)计算得到方差为0,在α取1.0的情况下,根据公式(4)计算得到惩罚为0,根据公式(7)计算得到总损失为0.3。
在一种可能的实现方式中,以目标损失和极差惩罚项计算总损失时。预测概率分布P1去掉正类的预测概率值之后的负类预测概率分布为[0.05,0.3,0.1,0.05,0.1],根据公式(1)计算得到目标损失为0.3,根据公式(5)计算得到负类预测的极差为0.295,在α取1.0的情况下,根据公式(6)计算得到惩罚为0.295,根据公式(7)计算得到总损失为0.595。预测概率分布P2去掉正类的预测概率值之后的分布为[0.10,0.10,0.10,0.10,0.10],根据公式(1)计算得到目标损失为0.3,根据公式(5)计算得到负类预测的极差为0,在α取1.0的情况下,根据公式(6)计算得到惩罚为0,根据公式(7)计算得到总损失为0.3。
在两种不同的预测概率分布P1和P2下,交叉熵损失由于只关注模型对正类预测概率的准确性,负类概率预测分布不同的两种结果得到的损失值也是一样的,使得模型训练的过程中,无法达到预测结果中负类预测概率分布均匀,正负类概率差别明显的目的。而本方案中使用以负类预测概率分布的离散程度作为惩罚项的新型损失函数来进行模型训练,模型朝着减少损失值的方向进行参数更新,即朝着减小惩罚项的方向,使得结果中负类预测结果的分布更均匀,增加正负类预测概率差别,减少模型的误判,使预测结果更准确。
以实体6分类任务为例,若实际标签分布为[0,1,0,1,0,0],即第二种类别和第四种类别为实体的正类,其余类别为负类,分类模型得到的两种预测概率分布分别为:
P1:[0.05,0.50,0.30,0.50,0.05,0.10];
P2:[0.10,0.50,0.10,0.50,0.10,0.10]。
预测概率分布P1表示模型预测实体属于第一种类别的概率为0.05,预测实体属于第二种类别的概率为0.50,预测实体属于第三种类别的概率为0.30,预测实体属于第四种类别的概率为0.50,预测实体属于第五种类别的概率为0.05,预测实体属于第六种类别的概率为0.10。
预测概率分布P2表示模型预测实体属于第一种类别的概率为0.10,预测实体属于第二种类别的概率为0.50,预测实体属于第三种类别的概率为0.10,预测实体属于第四种类别的概率为0.50,预测实体属于第五种类别的概率为0.10,预测实体属于第六种类别的概率为0.10。
在一种可能的实现方式中,以目标损失和方差惩罚项计算总损失函数。预测概率分布P1去掉正类的预测概率值之后的分布为[0.05,0.3,0.05,0.1],根据公式(1)计算得到目标损失为0.6,根据公式(2)计算得到负类预测的均值为0.125,根据公式(3)计算得到方差为0.0106,在α取1.0的情况下,根据公式(4)计算得到惩罚为0.043,根据公式(7)计算得到总损失为0.643。预测概率分布P2去掉正类的预测概率值之后的分布为[0.10,0.10,0.10,0.10],根据公式(1)计算得到目标损失为0.6,根据公式(2)计算得到负类预测的均值为0.10,根据公式(3)计算得到方差为0,在α取1.0的情况下,根据公式(4)计算得到惩罚为0,根据公式(7)计算得到总损失为0.6。
在一种可能的实现方式中,以目标损失和极差惩罚项计算总损失函数。预测概率分布P1去掉正类的预测概率值之后的分布为[0.05,0.3,0.05,0.1],根据公式(1)计算得到目标损失为0.6,根据公式(5)计算得到负类预测的极差为0.295,在α取1.0的情况下,根据公式(6)计算得到惩罚为0.295,根据公式(7)计算得到总损失为0.895。预测概率分布P2去掉正类的预测概率值之后的分布为[0.10,0.10,0.10,0.10],根据公式(1)计算得到目标损失为0.6,根据公式(5)计算得到负类预测的极差为0,在α取1.0的情况下,根据公式(6)计算得到惩罚为0,根据公式(7)计算得到总损失为0.6。
在两种不同的预测概率分布P1和P2下,以交叉熵损失计算模型预测误差的目标损失由于只关注模型对正类预测概率的准确性,负类概率预测分布不同的两种结果得到的损失值也是一样的,使得模型训练的过程中,无法达到预测结果中负类预测概率分布均匀,正负类概率差别明显的目的。而本方案中使用以负类预测概率分布的离散程度作为惩罚项的新型损失函数来进行模型训练,模型朝着减少损失值的方向进行参数更新,即朝着减小惩罚项的方向,使得结果中负类预测结果的分布更均匀,增加正负类预测概率差别,减少模型的误判,使预测结果更准确。在训练完成后,使用分类模型对待预测样本进行分类。以文本样本做实体分类为例,将待预测实体输入训练好的分类模型,输出分类预测结果。在待预测实体前添加特殊标志符[CLS]输入BERT模型,经过BERT模型处理后输出每个字符所对应的输出向量,选取特殊标志符[CLS]对应的输出向量与训练样本中实体对应的输出向量进行均值拼接后输入全连接层,对全连接层的输出使用Softmax做分类,得到待预测实体属于每一个类别的预测概率分布。在单一标签的多分类任务中,将最大概率对应的类别作为实体的类型,输出分类预测结果。在多标签分类任务中,将高于预设阈值的预测概率对应的类别作为实体的类型,输出分类预测结果。
本方案使用负类预测概率分布的离散程度作为惩罚项结合损失函数计算模型预测的总损失,根据总损失进行反向传播,向减少总损失的方向更新分类模型参数,使训练好的模型在预测时,预测概率分布中负类的预测概率分布更加均匀,正负类预测概率的差值更大,减少模型误判,提高模型的预测准确性。
使用本方案在自然语言处理的中文短文本实体24分类任务中进行了测试,在交叉验证实验中,超参数取固定值(α=1.2),使用本方案得到的模型普遍效果优于使用交叉熵。相较于只用交叉熵损失来训练模型,本方案使用的新型损失函数的F1 score提高了0.28个百分点,有效提高了模型的预测能力。
可选的,本申请中的分类模型还可以是图像分类模型或语音分类模型。
下面介绍本申请实施例中的一种分类模型训练装置,参见图3,图3为本申请实施例提供的一种分类模型训练装置,包括训练单元310和处理单元320。
训练单元310将训练样本输入分类模型,得到训练样本属于每个类别的预测概率分布,预测概率分布为分类模型预测的训练样本属于每个类别的预测概率。
处理单元320根据预测概率分布与训练样本的实际标签分布,计算目标损失,实际标签分布为训练样本属于每个类别的实际概率,目标损失用于指示分类模型预测概率分布与实际标签分布之间的误差;根据预测概率分布中负类的预测概率分布,计算惩罚项,惩罚项用于指示预测概率分布中负类的离散程度;将目标损失与惩罚项之和记为总损失,根据总损失更新分类模型参数,得到训练好的分类模型。
分类模型为神经网络模型,请参阅图4,图4是本申请实施例提供的神经网络模型的结构框图。应当理解的是,图4只是示意性示出了一种可能的结构,不应理解为唯一结构。如图4所示,神经网络模型400可以包括输入层410,卷积层/池化层420,其中池化层为可选的,以及神经网络层430。
下面详细描述卷积层/池化层420的结构。
如图4所示卷积层/池化层420可以包括如示例421-426层,在一种实现方式中,421层为卷积层,422层为池化层,423层为卷积层,424层为池化层,425为卷积层,426为池化层;在另一种实现方式中,421、422为卷积层,423为池化层,424、425为卷积层,426为池化层。即卷积层的输出可以作为随后的池化层的输入,也可以作为另一个卷积层的输入以继续进行卷积操作。
以卷积层421为例,卷积层421可以包括很多个卷积算子,卷积算子也称为核,其在图像处理中的作用相当于一个从输入图像矩阵中提取特定信息的过滤器,卷积算子本质上可以是一个权重矩阵,这个权重矩阵通常被预先定义,在对图像进行卷积操作的过程中,权重矩阵通常在输入图像上沿着水平方向一个像素接着一个像素(或两个像素接着两个像素,取决于步长的取值)的进行处理,从而完成从图像中提取特定特征的工作。该权重矩阵的大小应该与图像的大小相关。需要注意的是,权重矩阵的纵深维度和输入图像的纵深维度是相同的,在进行卷积运算的过程中,权重矩阵会延伸到输入图像的整个深度。因此,和一个单一的权重矩阵进行卷积会产生一个单一纵深维度的卷积化输出,但是大多数情况下不使用单一权重矩阵,而是应用维度相同的多个权重矩阵。每个权重矩阵的输出被堆叠起来形成卷积图像的纵深维度。不同的权重矩阵可以用来提取图像中不同的特征,例如一个权重矩阵用来提取图像边缘信息,另一个权重矩阵用来提取图像的特定颜色,又一个权重矩阵用来对图像中不需要的噪点进行模糊化,该多个权重矩阵维度相同,经过该多个维度相同的权重矩阵提取后的特征图维度也相同,再将提取到的多个维度相同的特征图合并形成卷积运算的输出。这些权重矩阵中的权重值在实际应用中需要经过大量的训练得到,通过训练得到的权重值形成的各个权重矩阵可以从输入图像中提取信息,从而帮助神经网络模型400进行正确的预测。
当神经网络模型400有多个卷积层的时候,初始的卷积层(例如421)往往提取较多的一般特征,该一般特征也可以称之为低级别的特征;随着神经网络模型400深度的加深,越往后的卷积层(例如426)提取到的特征越来越复杂,比如高级别的语义之类的特征,语义越高的特征越适用于待解决的问题。
由于常常需要减少训练参数的数量,因此卷积层之后常常需要周期性的引入池化层,即如图4中420所示例的421-426各层,可以是一层卷积层后面跟一层池化层,也可以是多层卷积层后面接一层或多层池化层。在图像处理过程中,池化层的唯一目的就是减少图像的空间大小。池化层可以包括平均池化算子和/或最大池化算子,以用于对输入图像进行采样得到较小尺寸的图像。平均池化算子可以在特定范围内对图像中的像素值进行计算产生平均值。最大池化算子可以在特定范围内取该范围内值最大的像素作为最大池化的结果。另外,就像卷积层中用权重矩阵的大小应该与图像大小相关一样,池化层中的运算符也应该与图像的大小相关。通过池化层处理后输出的图像尺寸可以小于输入池化层的图像的尺寸,池化层输出的图像中每个像素点表示输入池化层的图像的对应子区域的平均值或最大值。
下面详细描述神经网络层430的结构。
在经过卷积层/池化层420的处理后,神经网络模型400还不足以输出所需要的输出信息。因为如前,卷积层/池化层420只会提取特征,并减少输入图像带来的参数。然而为了生成最终的输出信息(所需要的类信息或别的相关信息),神经网络模型400需要利用神经网络层430来生成一个或者一组所需要的类的数量的输出。因此,在神经网络层430中可以包括多层隐含层(如图4所示的431、432至433)以及输出层440,该多层隐含层中所包含的参数可以根据具体的任务类型的相关训练数据进行预先训练得到,例如该任务类型可以包括图像识别,图像分类,图像超分辨率重建等等。应当理解的是,图4所示的三个隐含层1至3仅为示例性,在其他实施方式中可能包括不同数量的隐含层。
在神经网络层430中的多层隐含层之后,也就是整个神经网络模型400的最后层为输出层440,该输出层440具有类似分类交叉熵的损失函数,具体用于计算预测误差,一旦整个神经网络模型400的前向传播(如图4由410至440的传播为前向传播)完成,反向传播(如图4由440至410的传播为反向传播)就会开始更新前面提到的各层的权重值以及偏差,以减少神经网络模型400的损失及神经网络模型400通过输出层输出的结果和理想结果之间的误差。需要说明的是,如图4所示的神经网络模型400仅作为一种神经网络模型的示例,在具体的应用中,神经网络模型还可以以其他网络模型的形式存在,
请参阅图5,图5是本申请实施例提供的神经网络处理器的结构框图。如图5所示,神经网络处理器50的核心部分为运算电路503,控制器504控制运算电路503提取存储器(权重存储器或输入存储器)中的数据并进行运算。在一些实现方式中,运算电路503内部包括多个处理单元(Process Engine,PE)。在一些实现方式中,运算电路503是二维脉动阵列。运算电路503还可以是一维脉动阵列或者能够执行例如乘法和加法这样的数学运算的其它电子线路。在一些实现方式中,运算电路503是通用的矩阵处理器。
举例来说,假设有输入矩阵A,权重矩阵B,输出矩阵C。运算电路503从权重存储器502中取矩阵B相应的数据,并缓存在运算电路503中每一个PE上。运算电路503从输入存储器501中取矩阵A数据与矩阵B进行矩阵运算,得到的矩阵的部分结果或最终结果,保存在累加器508中。向量计算单元507可以对运算电路503的输出做进一步处理,如向量乘,向量加,指数运算,对数运算,大小比较等等。例如,向量计算单元507可以用于神经网络中非卷积/非FC层的网络计算,如池化(Pooling),批归一化(Batch Normalization),局部响应归一化(Local Response Normalization)等。在一些实现方式中,向量计算单元507将经处理的输出的向量存储到统一缓存器506。例如,向量计算单元507可以将非线性函数应用到运算电路503的输出,例如累加值的向量,用以生成激活值。在一些实现方式中,向量计算单元507生成归一化的值、合并值,或二者均有。在一些实现方式中,处理过的输出的向量能够用作到运算电路503的激活输入,例如用于在神经网络中的后续层中的使用。
请参阅图5,统一存储器506用于存放输入数据以及输出数据。存储单元访问控制器505(Direct Memory Access Controller,DMAC)将外部存储器中的输入数据搬运到输入存储器501和/或统一存储器506、将外部存储器中的权重数据存入权重存储器502,以及将统一存储器506中的数据存入外部存储器。总线接口单元(Bus Interface Unit,BIU)510用于通过总线实现主CPU、DMAC和取指存储器505之间进行交互。与控制器504连接的取指存储器(instruction fetch buffer)505用于存储控制器504使用的指令;控制器504用于调用指存储器505中缓存的指令,实现控制该运算加速器的工作过程。
一般地,统一存储器506,输入存储器501,权重存储器502以及取指存储器505均为片上(On-Chip)存储器,外部存储器为该NPU外部的存储器,该外部存储器可以为双倍数据率同步动态随机存储器(Double Data Rate Synchronous Dynamic Random AccessMemory,简称DDR SDRAM)、高带宽存储器(High Bandwidth Memory,HBM)或其他可读可写的存储器。
图6是本申请实施例提供的一种计算设备的结构示意图,计算设备600包括用于实现实体多分类模型训练方法的操作的模块,包括:一个或者多个处理器610、通信接口620以及存储器630。可选的,所述处理器610、通信接口620以及存储器630通过总线640相互连接,其中,
所述处理器610用于执行上述图1中S101-S103中所执行的步骤,在此不再赘述。
处理器610可以有多种具体实现形式,例如处理器610可以为中央处理器或图像处理器,处理器610还可以是单核处理器或多核处理器,处理器610还可以由CPU和硬件芯片的组合。
通信接口620可以为有线接口或无线接口,用于与其他模块或设备进行通信,有线接口可以是以太接口、局域互联网络(local interconnect network,LIN)等,无线接口可以是蜂窝网络接口或使用无线局域网接口等。
存储器630可以是非易失性存储器,例如,只读存储器(read-only memory,ROM)、可编程只读存储器(programmable ROM,PROM)、可擦除可编程只读存储器(erasable PROM,EPROM)、电可擦除可编程只读存储器(electrically EPROM,EEPROM)或闪存。存储器630也可以是易失性存储器,易失性存储器可以是随机存取存储器(random access memory,RAM),其用作外部高速缓存。
存储器630也可用于存储指令和数据,以便于处理器610调用存储器630中存储的指令实现上述S101-S103中执行的操作。此外,计算设备600可能包含相比于图6展示的更多或者更少的组件,或者有不同的组件配置方式。
总线640可以是外设部件互连标准(peripheral component interconnect,PCI)总线或扩展工业标准结构(extended industry standard architecture,简称EISA)总线等。所述总线640可以分为地址总线、数据总线、控制总线等。为便于表示,图6中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
可选地,该计算设备600还可以包括输入/输出接口650,输入/输出接口650连接有输入/输出设备,用于接收输入的信息,输出操作结果。
本申请实施例还提供一种非瞬态计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,当计算机程序在处理器上运行时,可以实现上述方法实施例中执行的方法步骤,所述计算机存储介质的处理器在执行上述方法步骤的具体实现可参照上述方法实施例中S101-S102的具体操作,在此不再赘述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及方法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,上述描述的装置、电子设备和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置、电子设备和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口、装置或单元的间接耦合或通信连接,也可以是电的,机械的或其它的形式连接。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分,或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围为准。

Claims (10)

1.一种分类模型训练方法,其特征在于,包括:
将训练样本输入分类模型,得到所述训练样本属于每个类别的预测概率分布,所述预测概率分布为所述分类模型预测的所述训练样本属于每个类别的预测概率;
根据所述预测概率分布与所述训练样本的实际标签分布,计算目标损失,所述实际标签分布为所述训练样本属于每个类别的实际概率,所述目标损失用于指示所述分类模型的预测概率分布与实际标签分布之间的误差;
根据所述预测概率分布中负类的预测概率分布,计算惩罚项,所述惩罚项用于指示所述预测概率分布中负类的离散程度;
将所述目标损失与所述惩罚项之和记为总损失,根据所述总损失更新所述分类模型参数,得到训练好的分类模型。
2.根据权利要求1所述的方法,其特征在于,所述根据所述预测概率分布中负类的预测概率分布,计算惩罚项,包括:
获取所述预测概率分布中负类的预测概率分布;
根据所述负类的预测概率分布计算所述负类的预测概率分布的方差;
根据分类标签数、正类标签数以及所述负类的预测概率分布的方差,确定所述惩罚项。
3.根据权利要求1所述的方法,其特征在于,所述根据所述预测概率分布中负类的预测概率分布,计算惩罚项,包括:
获取所述预测概率分布中负类的预测概率分布;
根据所述负类的预测概率分布计算所述负类的预测概率分布的极差,所述极差为所述负类的预测概率分布中最大预测概率与最小预测概率之差;
根据所述负类的预测概率分布的极差,确定所述惩罚项。
4.根据权利要求2或3所述的方法,其特征在于,所述根据所述预测概率分布与所述训练样本的所述实际标签分布,计算目标损失包括:
根据预测概率分布中第i个类别的预测概率与所述第i个类别的实际概率,计算交叉熵损失作为所述目标损失,其中,i的取值为1到N,N为分类标签数。
5.根据权利要求1所述的方法,其特征在于,所述根据所述总损失更新所述分类模型参数,得到训练好的分类模型包括:
基于所述总损失进行反向传播,计算所述分类模型中多个网络层的梯度;
基于所述多个网络层的梯度,对所述多个网络层进行参数更新,得到所述训练好的分类模型。
6.根据权利要求5所述的方法,其特征在于,所述根据所述总损失更新所述分类模型参数,得到训练好的分类模型之后,还包括:
将待预测样本输入到所述训练好的分类模型,得到所述待预测样本属于每个类别的预测概率分布;
将预测概率高于预设阈值的类别作为所述待预测样本的标签。
7.一种分类模型训练装置,其特征在于,包括训练单元和处理单元:
所述训练单元,用于将训练样本输入分类模型,得到所述训练样本属于每个类别的预测概率分布,所述预测概率分布为所述分类模型预测的所述训练样本属于每个类别的预测概率;
所述处理单元,用于根据所述预测概率分布与所述训练样本的实际标签分布,计算目标损失,所述实际标签分布为所述训练样本属于每个类别的实际概率,所述目标损失用于指示所述分类模型预测概率分布与实际标签分布之间的误差;
根据所述预测概率分布中负类的预测概率分布,计算惩罚项,所述惩罚项用于指示所述预测概率分布中负类的离散程度;
将所述目标损失与所述惩罚项之和记为总损失,根据所述总损失更新所述分类模型参数。
8.一种神经网络处理器,其特征在于,所述神经网络处理器包括用于实现权利要求1至6任一项所述的方法。
9.一种计算设备,其特征在于,包括处理器和存储器;所述存储器用于存储指令,所述处理器用于执行所述指令,当所述处理器执行所述指令时,所述计算设备执行如权利要求1至6任一项所述的方法。
10.一种计算机存储介质,其特征在于,所述计算机存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述的方法。
CN202110481964.2A 2021-04-30 2021-04-30 分类模型训练方法、装置、计算设备及存储介质 Pending CN113011532A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110481964.2A CN113011532A (zh) 2021-04-30 2021-04-30 分类模型训练方法、装置、计算设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110481964.2A CN113011532A (zh) 2021-04-30 2021-04-30 分类模型训练方法、装置、计算设备及存储介质

Publications (1)

Publication Number Publication Date
CN113011532A true CN113011532A (zh) 2021-06-22

Family

ID=76380524

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110481964.2A Pending CN113011532A (zh) 2021-04-30 2021-04-30 分类模型训练方法、装置、计算设备及存储介质

Country Status (1)

Country Link
CN (1) CN113011532A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113822371A (zh) * 2021-09-30 2021-12-21 支付宝(杭州)信息技术有限公司 训练分组模型,以及对时序数据进行分组的方法和装置
CN115630689A (zh) * 2022-12-21 2023-01-20 苏州大学 优化文本分类模型输出层激活函数的方法、设备和系统

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107273458A (zh) * 2017-06-01 2017-10-20 百度在线网络技术(北京)有限公司 深度模型训练方法及装置、图像检索方法及装置
CN109409318A (zh) * 2018-11-07 2019-03-01 四川大学 统计模型的训练方法、统计方法、装置及存储介质
CN109543821A (zh) * 2018-11-26 2019-03-29 济南浪潮高新科技投资发展有限公司 一种限制权重分布提高量化效果的卷积神经网络训练方法
CN109902722A (zh) * 2019-01-28 2019-06-18 北京奇艺世纪科技有限公司 分类器、神经网络模型训练方法、数据处理设备及介质
CN110503616A (zh) * 2019-08-28 2019-11-26 上海海事大学 一种应用于图片去噪的生成式网络
CN111177507A (zh) * 2019-12-31 2020-05-19 支付宝(杭州)信息技术有限公司 多标记业务处理的方法及装置
CN111553399A (zh) * 2020-04-21 2020-08-18 佳都新太科技股份有限公司 特征模型训练方法、装置、设备及存储介质
CN111680698A (zh) * 2020-04-21 2020-09-18 北京三快在线科技有限公司 图像识别方法、装置及图像识别模型的训练方法、装置
CN111914944A (zh) * 2020-08-18 2020-11-10 中国科学院自动化研究所 基于动态样本选择和损失一致性的物体检测方法和系统
CN112465017A (zh) * 2020-11-26 2021-03-09 平安科技(深圳)有限公司 分类模型训练方法、装置、终端及存储介质

Patent Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107273458A (zh) * 2017-06-01 2017-10-20 百度在线网络技术(北京)有限公司 深度模型训练方法及装置、图像检索方法及装置
CN109409318A (zh) * 2018-11-07 2019-03-01 四川大学 统计模型的训练方法、统计方法、装置及存储介质
CN109543821A (zh) * 2018-11-26 2019-03-29 济南浪潮高新科技投资发展有限公司 一种限制权重分布提高量化效果的卷积神经网络训练方法
CN109902722A (zh) * 2019-01-28 2019-06-18 北京奇艺世纪科技有限公司 分类器、神经网络模型训练方法、数据处理设备及介质
CN110503616A (zh) * 2019-08-28 2019-11-26 上海海事大学 一种应用于图片去噪的生成式网络
CN111177507A (zh) * 2019-12-31 2020-05-19 支付宝(杭州)信息技术有限公司 多标记业务处理的方法及装置
CN111553399A (zh) * 2020-04-21 2020-08-18 佳都新太科技股份有限公司 特征模型训练方法、装置、设备及存储介质
CN111680698A (zh) * 2020-04-21 2020-09-18 北京三快在线科技有限公司 图像识别方法、装置及图像识别模型的训练方法、装置
CN111914944A (zh) * 2020-08-18 2020-11-10 中国科学院自动化研究所 基于动态样本选择和损失一致性的物体检测方法和系统
CN112465017A (zh) * 2020-11-26 2021-03-09 平安科技(深圳)有限公司 分类模型训练方法、装置、终端及存储介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113822371A (zh) * 2021-09-30 2021-12-21 支付宝(杭州)信息技术有限公司 训练分组模型,以及对时序数据进行分组的方法和装置
CN115630689A (zh) * 2022-12-21 2023-01-20 苏州大学 优化文本分类模型输出层激活函数的方法、设备和系统

Similar Documents

Publication Publication Date Title
CN110020592B (zh) 物体检测模型训练方法、装置、计算机设备及存储介质
CN108073902B (zh) 基于深度学习的视频总结方法、装置及终端设备
CN111814810A (zh) 图像识别方法、装置、电子设备及存储介质
CN111026544B (zh) 图网络模型的节点分类方法、装置及终端设备
CN109902716B (zh) 一种对齐分类模型的训练方法和图像分类方法
CN110929836B (zh) 神经网络训练及图像处理方法和装置、电子设备、介质
CN111583911B (zh) 基于标签平滑的语音识别方法、装置、终端及介质
CN113128671B (zh) 一种基于多模态机器学习的服务需求动态预测方法及系统
CN111428557A (zh) 基于神经网络模型的手写签名的自动校验的方法和装置
KR102250728B1 (ko) 샘플 처리 방법, 장치, 기기 및 저장 매체
CN109726291B (zh) 分类模型的损失函数优化方法、装置及样本分类方法
CN111105017A (zh) 神经网络量化方法、装置及电子设备
CN113011532A (zh) 分类模型训练方法、装置、计算设备及存储介质
CN112749737A (zh) 图像分类方法及装置、电子设备、存储介质
EP4343616A1 (en) Image classification method, model training method, device, storage medium, and computer program
CN114419378B (zh) 图像分类的方法、装置、电子设备及介质
CN111062440A (zh) 一种样本选择方法、装置、设备及存储介质
CN113449840A (zh) 神经网络训练方法及装置、图像分类的方法及装置
CN113239697B (zh) 实体识别模型训练方法、装置、计算机设备及存储介质
CN110717407A (zh) 基于唇语密码的人脸识别方法、装置及存储介质
CN113902944A (zh) 模型的训练及场景识别方法、装置、设备及介质
CN113762005A (zh) 特征选择模型的训练、对象分类方法、装置、设备及介质
CN112818946A (zh) 年龄识别模型的训练、年龄识别方法、装置及电子设备
CN116109907B (zh) 目标检测方法、装置、电子设备及存储介质
CN109657710B (zh) 数据筛选方法、装置、服务器及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination