CN112465017A - 分类模型训练方法、装置、终端及存储介质 - Google Patents

分类模型训练方法、装置、终端及存储介质 Download PDF

Info

Publication number
CN112465017A
CN112465017A CN202011348555.7A CN202011348555A CN112465017A CN 112465017 A CN112465017 A CN 112465017A CN 202011348555 A CN202011348555 A CN 202011348555A CN 112465017 A CN112465017 A CN 112465017A
Authority
CN
China
Prior art keywords
prediction
loss function
task
function value
classification model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011348555.7A
Other languages
English (en)
Inventor
吴天博
王健宗
程宁
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202011348555.7A priority Critical patent/CN112465017A/zh
Publication of CN112465017A publication Critical patent/CN112465017A/zh
Priority to PCT/CN2021/083844 priority patent/WO2021208722A1/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种分类模型训练方法、装置、终端及存储介质,该方法在分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务,将样本输入分类模型后,经第一任务预测得到第一预测概率,经第二任务预测得到第二预测概率;再利用第一预测概率计算第一损失函数值,利用第二预测概率计算第二损失函数值,以及利用第一预测概率和第二预测概率计算第一任务和第二任务的关联损失函数值;基于第一损失函数值、第二损失函数值、关联损失函数值计算分类模型的最终损失函数值,再反向传播更新分类模型。本发明通过上述方式提前告知分类模型在进行预测时需要避免的非期望预测类别,提高预测的准确率,同时也降低预测错误时造成的影响。

Description

分类模型训练方法、装置、终端及存储介质
技术领域
本申请涉及人工智能技术领域,特别是涉及一种分类模型训练方法、装置、终端及存储介质。
背景技术
随着技术的发展,人工智能(Artificial Intelligence,AI)领域的发展日新月异,特别是随着深度学习技术的广泛应用,其在物体检测、识别等领域取得了突破性的进展。目前,在图像识别、语音识别、声纹识别等领域,会采用基础的分类网络进行训练得到特征,再进行进一步的分类,通过分类来识别出输入数据所属的人或者语音内容等,而基于该种方式训练得到的分类模型,其已经达到了较高的准确率。但是,在现实生产应用中,往往存在一些特殊情况,某一类别被误分类为另一类别时可能会导致很严重的影响,例如:正面情绪被误分类为负面情绪、取消预约被误分类为预约意图、胃着凉被误分类为胃癌等。
针对于此类问题,目前通用做法仅仅是在损失函数上直接进行改进,对其中某种预测结果的出现给予更高的惩罚因子,比如focal loss会给予更高的损失权重;或者是,在分类模型预测结果为A的基础上,将样本再重新做一次分类。但是,上述方式并不能有效解决上述问题,第一种方式只能缓解上述分类问题,并不能体现出预测类别互斥的概念;第二种方式只能针对于分类规则简单的问题,当分类规则复杂时,会导致模型的复杂程度直线上升,而且该方式并不能让模型有效的学习到重要特征。
发明内容
本申请提供一种分类模型训练方法、装置、终端及存储介质,以解决现有分类模型无法有效地提前排除掉部分非期望分类结果的问题。
为解决上述技术问题,本申请采用的一个技术方案是:提供一种分类模型训练方法,包括:在待训练的分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务;将预先准备好的样本输入至分类模型,经第一任务预测得到各预测类别对应的第一预测概率,经第二任务预测得到各预测类别对应的第二预测概率;利用每个第一预测概率计算第一任务的第一损失函数值,同时利用每个第二预测概率计算第二任务的第二损失函数值,以及利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值;基于第一损失函数值、第二损失函数值、关联损失函数值计算分类模型的最终损失函数值;根据最终损失函数值反向传播更新分类模型。
作为本申请的进一步改进,利用第一预测概率计算第一任务的第一损失函数值,包括:构建期望预测类别对应的正确标签向量,正确标签向量包括每个预测类别的第一标签值,期望预测类别对应的第一标签值为1,其余预测类别对应的第一标签值为0;将正确标签向量和第一预测概率输入至第一预设损失函数中计算得到第一损失函数值。
作为本申请的进一步改进,利用第二预测概率计算第二任务的第二损失函数值,包括:构建非期望结果对应的错误标签向量,错误标签向量包括每个预测类别的第二标签值,非期望预测类别对应的第二标签值为1,其余预测类别对应的第二标签值为0;将错误标签向量和第二预测概率输入至第二预设损失函数计算得到每个预测类别的初始损失函数值;按照预设处理规则处理每个初始损失函数值,得到第二损失函数值。
作为本申请的进一步改进,按照预设处理规则处理初始损失函数值,得到第二损失函数值,包括:获取错误标签向量中每个预测类别对应的第二标签值;逐个利用每个初始损失函数值乘以目标数值后再累加,得到第二损失函数值,在每个初始损失函数值乘以目标数值时,目标数值按照预设概率取第二标签值,否则取1;第二损失函数值的计算公式为:
Figure BDA0002800604220000021
其中,p2表示第二预测概率,l2表示第二标签值,L′(p2,l2)表示第二损失函数值,
Figure BDA0002800604220000022
表示第i个预测类别对应的第二预测概率,
Figure BDA0002800604220000031
表示第i个预测类别对应的第二标签值,
Figure BDA0002800604220000032
表示初始损失函数值,p表示预设概率,
Figure BDA0002800604220000033
表示以预设概率p取
Figure BDA0002800604220000034
否则取1,n表示预测类别的数量。
作为本申请的进一步改进,当使用训练好的分类模型进行预测时,第二任务对应的错误标签向量的每个第二标签值均取1。
作为本申请的进一步改进,利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值的计算公式为:
Figure BDA0002800604220000035
其中,p1表示第一预测概率,p2表示第二预测概率,L(p1,p2)表示关联损失函数值,
Figure BDA0002800604220000036
表示第i个预测类别对应的第一预测概率,
Figure BDA0002800604220000037
表示第i个预测类别对应的第二预测概率,n表示预测类别的数量。
作为本申请的进一步改进,样本包括多个历史文本数据,分类模型训练好之后,用于实现对文本进行分类预测。。
为解决上述技术问题,本申请采用的另一个技术方案是:提供一种分类模型训练装置,包括:构建模块,用于在待训练的分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务;预测模块,用于将预先准备好的样本输入至分类模型,经第一任务预测得到各预测类别对应的第一预测概率,经第二任务预测得到各预测类别对应的第二预测概率;第一计算模块,用于利用每个第一预测概率计算第一任务的第一损失函数值,同时利用每个第二预测概率计算第二任务的第二损失函数值,以及利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值;第二计算模块,用于基于第一损失函数值、第二损失函数值、关联损失函数值计算分类模型的最终损失函数值;训练模块,用于根据最终损失函数值反向传播更新分类模型。
为解决上述技术问题,本申请采用的再一个技术方案是:提供一种终端,该终端包括处理器、与处理器耦接的存储器,其中,存储器存储有用于实现上述分类模型训练方法的程序指令;处理器用于执行存储器存储的程序指令以基于多任务来训练分类模型。
为解决上述技术问题,本申请采用的再一个技术方案是:提供一种存储介质,存储有能够实现上述分类模型训练方法的程序文件。
本申请的有益效果是:本申请的分类模型训练方法通过在分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务,当对分类模型进行训练时,计算出第一任务的第一损失函数值和第二任务的第二损失函数值,并且,通过第一任务得到的第一预测概率和第二任务得到的第二预测概率,计算出第一任务和第二任务的关联损失函数值,利用关联损失函数值来强调第二任务的意图,尽可能增大第一任务和第二任务之间预测值的差异性,最后根据第一损失函数值、第二损失函数值、关联损失函数值计算得到最终损失函数值,再利用该最终损失函数值反向传播更新分类模型,其通过针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务进行训练,相当于提前告知分类模型那些预测类别不能作为该样本的输出,使得分类模型能够提前排除掉会造成严重错误或影响部分预测结果,其一方面能够起到缩小预测结果范围的目的,使得预测结果更为准确,另一方面避免出现对实际情况会产生严重影响的预测结果。
附图说明
图1是本发明实施例的分类模型训练方法的流程示意图;
图2是本发明实施例的分类模型训练装置的功能模块示意图;
图3是本发明实施例的终端的结构示意图;
图4是本发明实施例的存储介质的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请中的术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”、“第三”的特征可以明示或者隐含地包括至少一个该特征。本申请的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。本申请实施例中所有方向性指示(诸如上、下、左、右、前、后……)仅用于解释在某一特定姿态(如附图所示)下各部件之间的相对位置关系、运动情况等,如果该特定姿态发生改变时,则该方向性指示也相应地随之改变。此外,术语“包括”和“具有”以及它们任何变形,意图在于覆盖不排他的包含。例如包含了一系列步骤或单元的过程、方法、系统、产品或设备没有限定于已列出的步骤或单元,而是可选地还包括没有列出的步骤或单元,或可选地还包括对于这些过程、方法、产品或设备固有的其它步骤或单元。
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
图1是本发明实施例的分类模型训练方法的流程示意图。需注意的是,若有实质上相同的结果,本发明的方法并不以图1所示的流程顺序为限。如图1所示,该方法包括步骤:
步骤S101:在待训练的分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务。
本实施例中所述的分类模型优选为BERT模型,需要说明的是,该分类模型不局限于BERT模型,其他适用于文本分类的模型都可以采用本实施例所要求保护的分类模型训练方法进行训练。
在步骤S101中,针对于待训练的分类模型,首先在分类模型中构建第一任务和第二任务,其中,第一任务针对于期望预测类别,该期望预测类别即用户希望输出的预测结果,第二任务针对于非期望预测类别,该非期望预测类别即用户希望不能输出的预测结果,例如当一个分类模型用于辨识生物,预测类别包括老人、小孩、成年人和猫,当该生物为老人时,则输出的预测类别不能是猫,该分类模型的期望预测类别为老人,非期望预测类别为猫。
需要理解的是,本实施例中,第一任务是本分类模型的主要任务,用来输出最终的预测结果,而第二任务是第一任务的辅助任务,用来告知第一任务需要避免掉的非期望预测类别,因此,第二任务作为底层任务,第一任务作为高层任务。
步骤S102:将预先准备好的样本输入至分类模型,经第一任务预测得到各预测类别对应的第一预测概率,经第二任务预测得到各预测类别对应的第二预测概率。
需要理解的是,分类问题通常可以划分为二分类问题、多分类问题、多标签分类问题。其中,二分类问题表示分类任务中有两个类别,比如识别一幅图片是不是猫,则输出结果只有是猫和不是猫两种情况,即二分类是假设每个样本都被设置了一个且仅有一个标签0或者1;多分类问题表示分类任务中有多个类别,比如对一堆水果图片分类,每张样本上的水果可能是橘子、苹果、梨等,而多分类则是假设每个样本都被设置了一个且仅有一个标签,即一个样本上的图片可以是苹果或者梨,但是同时不可能同时是苹果和梨;多标签分类是指给每个样本一系列的目标标签,可以想象成一个数据点的各属性不是相互排斥的(例如,一个水果既是苹果又是梨就是相互排斥的),比如一个文档记录的话题,该文档可能被认为是同时涉及到了饮食、健康、金融或者教育相关的话题。
本实施例中,第一任务的目的是要输出一个期望预测类别,属于多分类问题,而第二任务中,非期望预测类别可以为多个,即最终的输出结果可能涉及到多个预测类别,因此,第二任务属于多标签分类问题。
优选地,在构建该第一任务和第二任务时,该第一任务以softmax函数作为激活函数进行构建,该第二任务以sigmoid函数作为激活函数进行构建。
进一步的,针对于第一任务,优选采用多类别交叉熵损失函数,而针对于第二任务,优选采用二分类交叉熵损失函数。
在步骤S102中,在构建好第一任务和第二任务之后,将样本输入值该分类模型后,在第一任务中通过softmax函数预测得到各预测类别的第一预测概率,在第二任务中通过sigmoid函数预测得到各预测类别的第二预测概率。其中,softmax函数和sigmoid函数均是很成熟的技术,此处不再赘述。
步骤S103:利用每个第一预测概率计算第一任务的第一损失函数值,同时利用每个第二预测概率计算第二任务的第二损失函数值,以及利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值。
在步骤S103中,第一任务采用softmax函数作为激活函数,则第一任务通过多类别交叉熵损失函数来计算第一损失函数值。第二任务采用sigmoid函数作为激活函数,则第二任务通过二分类交叉熵损失函数来计算第二损失函数值。
本实施例中,在将样本输入至分类模型,利用第一任务和第二任务进行预测之前,需要先为样本构建正确标签向量和错误标签向量,该正确标签向量用于表示第一任务的期望预测类别,该错误标签向量用于表示第二任务的非期望预测类别。具体地,该步骤S103具体包括:
1、利用第一预测概率计算第一任务的第一损失函数值,具体包括:
1.1构建期望预测类别对应的正确标签向量,正确标签向量包括每个预测类别的第一标签值,期望预测类别对应的第一标签值为1,其余预测类别对应的第一标签值为0。
具体地,为了方便计算,在构建正确标签向量时,每个预测类别对应的第一标签值按照预测类别的排列顺序排列。例如,假设当前有A、B、C、D、E、F六种预测类别,针对于样本a,其第一任务的期望预测类别为A类,因此,该样本a的正确标签向量中,A类对应的第一标签值为1,B、C、D、E、F类对应的第一标签值为0,即该正确标签向量为[1,0,0,0,0,0]。
1.2将正确标签向量和第一预测概率输入至第一预设损失函数中计算得到第一损失函数值。
具体地,该第一预设损失函数的计算公式为:
Figure BDA0002800604220000081
其中,p1表示第一预测概率,l1表示第一标签值,L(p1,l1)表示第一损失函数值,
Figure BDA0002800604220000082
表示第i个预测类别对应的第一标签值,
Figure BDA0002800604220000083
表示第i个预测类别对应的第一预测概率,n表示预测类别的数量。
2、利用第二预测概率计算第二任务的第二损失函数值,具体包括:
2.1构建非期望结果对应的错误标签向量,错误标签向量包括每个预测类别的第二标签值,非期望预测类别对应的第二标签值为1,其余预测类别对应的第二标签值为0。
具体地,每个预测类别对应的第一标签值按照预测类别的排列顺序排列。继续以上述例子进行描述,针对于样本a,其第二任务的非期望预测类别为D、E类,因此,该样本a的错误标签向量中,D、E类对应的第二标签值为1,A、B、C、F对应的第二标签值为0,即该错误标签向量为[0,0,0,1,1,0]。
2.2将错误标签向量和第二预测概率输入至第二预设损失函数计算得到每个预测类别的初始损失函数值。
具体地,该第二预设损失函数的计算公式为:
Figure BDA0002800604220000084
其中,p2表示第二预测概率,l2表示第二标签值,L(p2,l2)表示初始损失函数值,
Figure BDA0002800604220000085
表示第i个预测类别对应的第二预测概率,
Figure BDA0002800604220000086
表示第i个预测类别对应的第二标签值,n表示预测类别的数量。
2.3按照预设处理规则处理每个初始损失函数值,得到第二损失函数值。
具体地,在计算第二任务的损失函数值时,考虑到使用该分类模型进预测时,无法获知第二任务的错误标签向量,因此,在对分类模型进行训练时,需要减少这种差异性,从而提升训练的效果,保证最终预测结果的准确性,因此,该按照预设处理规则处理初始损失函数值,得到第二损失函数值的步骤,包括:
a.获取错误标签向量中每个预测类别对应的第二标签值。
具体地,在得到每个预测类别对应的初始损失函数值之后,获取每个预测类别对应的第二标签值,例如,当错误标签向量为[0,0,0,1,1,0]时,即A、B、C、F对应的第二标签值为0,D、E类对应的第二标签值为1。
b.逐个利用每个初始损失函数值乘以目标数值后再累加,得到第二损失函数值,在每个初始损失函数值乘以目标数值时,目标数值按照预设概率取第二标签值,否则取1。
具体地,本实施例中,借鉴dropout思想,通过预先设置预设概率p,在计算第二损失函数值时,根据该预设概率p选择乘以预测类别对应的第二标签值还是乘以1,具体参考该第二损失函数值的计算公式:
Figure BDA0002800604220000091
其中,p2表示第二预测概率,l2表示第二标签值,L′(p2,l2)表示第二损失函数值,
Figure BDA0002800604220000092
表示第i个预测类别对应的第二预测概率,
Figure BDA0002800604220000093
表示第i个预测类别对应的第二标签值,
Figure BDA0002800604220000094
表示初始损失函数值,p表示预设概率,
Figure BDA0002800604220000095
表示以预设概率p取
Figure BDA0002800604220000096
否则取1,n表示预测类别的数量。
优选地,该预设概率p优选为95%,经试验预设概率p为95%时,分类模型的训练结果较优。
进一步的,当使用训练好的分类模型进行预测时,第二任务对应的错误标签向量的每个第二标签值均取1。
本实施例中,通过设置预设概率p,则在训练过程中,也可能会出现第二标签值均为1的错误标签向量,其与使用分类模型预测时,错误标签向量的每个第二标签值均去1的情况相同,采用此种方式进行训练从而使得分类模型的训练过程与真实环境更为接近,起到了较好的训练效果。
3、利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值。
其中,关联损失函数值的计算公式为:
Figure BDA0002800604220000097
其中,p1表示第一预测概率,p2表示第二预测概率,L(p1,p2)表示关联损失函数值,
Figure BDA0002800604220000101
表示第i个预测类别对应的第一预测概率,
Figure BDA0002800604220000102
表示第i个预测类别对应的第二预测概率,n表示预测类别的数量。
具体地,考虑到第一任务与第二任务之间的互斥关系,可能会导致网络学习不稳定,因此,通过设计上述关联损失函数,尽可能增大第一任务和第二任务预测值的差异性。具体参考上述关联损失函数值的计算公式:
Figure BDA0002800604220000103
表示第i个预测类别对应的第二预测概率,为一个(0,1)之间的数,当
Figure BDA0002800604220000104
越大时,
Figure BDA0002800604220000105
越小,根据log函数的性质,
Figure BDA0002800604220000106
越小则
Figure BDA0002800604220000107
值越大,而分类模型训练的目的就是使得损失函数值小,因此,为了使得关联损失函数值L(p1,p2)小,当
Figure BDA0002800604220000108
值越大时,则
Figure BDA0002800604220000109
的值应该越小,从而使得
Figure BDA00028006042200001010
Figure BDA00028006042200001011
的乘积越小,因此,通过上述分析可知,当
Figure BDA00028006042200001012
越大时,
Figure BDA00028006042200001013
的值应尽可能小,从而使得关联损失函数值L(p1,p2)小以达到训练的目的,从而
Figure BDA00028006042200001014
Figure BDA00028006042200001015
的差异性越大。
步骤S104:基于第一损失函数值、第二损失函数值、关联损失函数值计算分类模型的最终损失函数值。
具体地,该最终损失函数值的计算公式为:
Loss=L(p1,l1)+L′(p2,l2)+L(p1,p2);
其中,Loss表示最终损失函数值,L(p1,l1)表示第一损失函数值,L′(p2,l2)表示第二损失函数值,L(p1,p2)表示关联损失函数值。
步骤S105:根据最终损失函数值反向传播更新分类模型。
在步骤S105中,通过最终损失函数反向传播更新分类模型的方案内容已经很成熟,此处不再赘述。
本发明实施例的分类模型训练方法通过在分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务,当对分类模型进行训练时,计算出第一任务的第一损失函数值和第二任务的第二损失函数值,并且,通过第一任务得到的第一预测概率和第二任务得到的第二预测概率,计算出第一任务和第二任务的关联损失函数值,利用关联损失函数值来强调第二任务的意图,尽可能增大第一任务和第二任务之间预测值的差异性,最后根据第一损失函数值、第二损失函数值、关联损失函数值计算得到最终损失函数值,再利用该最终损失函数值反向传播更新分类模型,其通过针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务进行训练,相当于提前告知分类模型那些预测类别不能作为该样本的输出,使得分类模型能够提前排除掉会造成严重错误或影响部分预测结果,其一方面能够起到缩小预测结果范围的目的,使得预测结果更为准确,另一方面避免出现对实际情况会产生严重影响的预测结果。
进一步的,分类模型可以广泛应用文本识别、图片识别、音频识别等领域,本实施例中,优选地,样本包括多个历史文本数据,利用该多个历史文本数据对待训练的分类模型按照上述分类模型训练方法进行训练,当分类模型训练好之后,该分类模型可以用于实现对文本进行分类预测。例如,该分类模型为BERT(Bidirectional EncoderRepresentations from Transformer)模型时,采用上述分类模型训练方法进行训练后,该BERT模型在对文本识别的效果更好,识别结果更为准确。
进一步的,再根据最终损失函数值反向传播更新分类模型之后,还包括:将训练好的分类模型上传至区块链中。
具体地,基于训练好的分类模型得到对应的摘要信息,具体来说,摘要信息由训练好的分类模型进行散列处理得到,比如利用sha256s算法处理得到。将摘要信息上传至区块链可保证其安全性和对用户的公正透明性。用户设备可以从区块链中下载得该摘要信息,以便查证分类模型是否被篡改。本示例所指区块链是分布式数据存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。区块链(Blockchain),本质上是一个去中心化的数据库,是一串使用密码学方法相关联产生的数据块,每一个数据块中包含了一批次网络交易的信息,用于验证其信息的有效性(防伪)和生成下一个区块。区块链可以包括区块链底层平台、平台产品服务层以及应用服务层等。
图2是本发明实施例的分类模型训练装置的功能模块示意图。如图2所示,该分类模型训练装置20包括构建模块21、预测模块22、第一计算模块23、第二计算模块24和训练模块25。
其中,构建模块21,用于在待训练的分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务;预测模块22,用于将预先准备好的样本输入至分类模型,经第一任务预测得到各预测类别对应的第一预测概率,经第二任务预测得到各预测类别对应的第二预测概率;第一计算模块23,用于利用每个第一预测概率计算第一任务的第一损失函数值,同时利用每个第二预测概率计算第二任务的第二损失函数值,以及利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值;第二计算模块24,用于基于第一损失函数值、第二损失函数值、关联损失函数值计算分类模型的最终损失函数值;训练模块25,用于根据最终损失函数值反向传播更新分类模型。
可选地,在一些实施例中,第一计算模块23利用第一预测概率计算第一任务的第一损失函数值的操作还可以为:构建期望预测类别对应的正确标签向量,正确标签向量包括每个预测类别的第一标签值,期望预测类别对应的第一标签值为1,其余预测类别对应的第一标签值为0;将正确标签向量和第一预测概率输入至第一预设损失函数中计算得到第一损失函数值。
可选地,在一些实施例中,第一计算模块23利用第二预测概率计算第二任务的第二损失函数值的操作还可以为:构建非期望结果对应的错误标签向量,错误标签向量包括每个预测类别的第二标签值,非期望预测类别对应的第二标签值为1,其余预测类别对应的第二标签值为0;将错误标签向量和第二预测概率输入至第二预设损失函数计算得到每个预测类别的初始损失函数值;按照预设处理规则处理每个初始损失函数值,得到第二损失函数值。
可选地,在一些实施例中,第一计算模块23按照预设处理规则处理初始损失函数值,得到第二损失函数值的操作还可以为:获取错误标签向量中每个预测类别对应的第二标签值;逐个利用每个初始损失函数值乘以目标数值后再累加,得到第二损失函数值,在每个初始损失函数值乘以目标数值时,目标数值按照预设概率取第二标签值,否则取1;第二损失函数值的计算公式为:
Figure BDA0002800604220000121
Figure BDA0002800604220000131
其中,p2表示第二预测概率,l2表示第二标签值,L′(p2,l2)表示第二损失函数值,
Figure BDA0002800604220000132
表示第i个预测类别对应的第二预测概率,
Figure BDA0002800604220000133
表示第i个预测类别对应的第二标签值,
Figure BDA0002800604220000134
表示初始损失函数值,p表示预设概率,
Figure BDA0002800604220000135
表示以预设概率p取
Figure BDA0002800604220000136
否则取1,n表示预测类别的数量。
可选地,在一些实施例中,当使用训练好的分类模型进行预测时,第二任务对应的错误标签向量的每个第二标签值均取1。
可选地,在一些实施例中,第一计算模块23利用每个第一预测概率和每个第二预测概率计算第一任务和第二任务的关联损失函数值的计算公式为:
Figure BDA0002800604220000137
其中,p1表示第一预测概率,p2表示第二预测概率,L(p1,p2)表示关联损失函数值,
Figure BDA0002800604220000138
表示第i个预测类别对应的第一预测概率,
Figure BDA0002800604220000139
表示第i个预测类别对应的第二预测概率,n表示预测类别的数量。
可选地,在一些实施例中,样本包括多个历史文本数据,分类模型训练好之后,用于实现对文本进行分类预测。
关于上述实施例分类模型训练装置中各模块实现技术方案的其他细节,可参见上述实施例中的分类模型训练方法中的描述,此处不再赘述。
需要说明的是,本说明书中的各个实施例均采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似的部分互相参见即可。对于装置类实施例而言,由于其与方法实施例基本相似,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
请参阅图3,图3为本发明实施例的终端的结构示意图。如图3所示,该终端30包括处理器31及和处理器31耦接的存储器32。
存储器32存储有用于实现上述任一实施例所述的分类模型训练方法的程序指令。
处理器31用于执行存储器32存储的程序指令以基于多任务来训练分类模型。
其中,处理器31还可以称为CPU(Central Processing Unit,中央处理单元)。处理器31可能是一种集成电路芯片,具有信号的处理能力。处理器31还可以是通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
参阅图4,图4为本发明实施例的存储介质的结构示意图。本发明实施例的存储介质存储有能够实现上述所有方法的程序文件41,其中,该程序文件41可以以软件产品的形式存储在上述存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本申请各个实施方式所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质,或者是计算机、服务器、手机、平板等终端设备。
在本申请所提供的几个实施例中,应该理解到,所揭露的终端,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。以上仅为本申请的实施方式,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。

Claims (10)

1.一种分类模型训练方法,其特征在于,包括:
在待训练的分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务;
将预先准备好的样本输入至所述分类模型,经所述第一任务预测得到各预测类别对应的第一预测概率,经所述第二任务预测得到所述各预测类别对应的第二预测概率;
利用每个所述第一预测概率计算所述第一任务的第一损失函数值,同时利用每个所述第二预测概率计算所述第二任务的第二损失函数值,以及利用每个所述第一预测概率和每个所述第二预测概率计算所述第一任务和所述第二任务的关联损失函数值;
基于所述第一损失函数值、所述第二损失函数值、所述关联损失函数值计算所述分类模型的最终损失函数值;
根据所述最终损失函数值反向传播更新所述分类模型。
2.根据权利要求1所述的分类模型训练方法,其特征在于,所述利用所述第一预测概率计算所述第一任务的第一损失函数值,包括:
构建所述期望预测类别对应的正确标签向量,所述正确标签向量包括每个预测类别的第一标签值,所述期望预测类别对应的所述第一标签值为1,其余所述预测类别对应的所述第一标签值为0;
将所述正确标签向量和所述第一预测概率输入至第一预设损失函数中计算得到所述第一损失函数值。
3.根据权利要求1所述的分类模型训练方法,其特征在于,所述利用所述第二预测概率计算所述第二任务的第二损失函数值,包括:
构建所述非期望结果对应的错误标签向量,所述错误标签向量包括每个预测类别的第二标签值,所述非期望预测类别对应的所述第二标签值为1,其余所述预测类别对应的所述第二标签值为0;
将所述错误标签向量和所述第二预测概率输入至第二预设损失函数计算得到每个所述预测类别的初始损失函数值;
按照预设处理规则处理每个所述初始损失函数值,得到所述第二损失函数值。
4.根据权利要求3所述的分类模型训练方法,其特征在于,所述按照预设处理规则处理所述初始损失函数值,得到所述第二损失函数值,包括:
获取所述错误标签向量中每个所述预测类别对应的第二标签值;
逐个利用每个所述初始损失函数值乘以目标数值后再累加,得到所述第二损失函数值,在每个所述初始损失函数值乘以所述目标数值时,所述目标数值按照预设概率取所述第二标签值,否则取1;所述第二损失函数值的计算公式为:
Figure FDA0002800604210000021
其中,所述p2表示所述第二预测概率,所述l2表示所述第二标签值,所述L′(p2,l2)表示所述第二损失函数值,所述
Figure FDA0002800604210000022
表示第i个预测类别对应的第二预测概率,所述
Figure FDA0002800604210000023
表示第i个预测类别对应的所述第二标签值,所述
Figure FDA0002800604210000024
表示所述初始损失函数值,所述p表示预设概率,所述
Figure FDA0002800604210000025
表示以所述预设概率p取所述
Figure FDA0002800604210000026
否则取1,所述n表示所述预测类别的数量。
5.根据权利要求4所述的分类模型训练方法,其特征在于,当使用训练好的所述分类模型进行预测时,所述第二任务对应的所述错误标签向量的每个所述第二标签值均取1。
6.根据权利要求1所述的分类模型训练方法,其特征在于,所述利用每个所述第一预测概率和每个所述第二预测概率计算所述第一任务和所述第二任务的关联损失函数值的计算公式为:
Figure FDA0002800604210000027
其中,所述p1表示所述第一预测概率,所述p2表示所述第二预测概率,所述L(p1,p2)表示所述关联损失函数值,所述
Figure FDA0002800604210000028
表示第i个预测类别对应的第一预测概率,所述
Figure FDA0002800604210000029
表示第i个预测类别对应的第二预测概率,所述n表示所述预测类别的数量。
7.根据权利要求1所述的分类模型训练方法,其特征在于,所述样本包括多个历史文本数据,所述分类模型训练好之后,用于实现对文本进行分类预测。
8.一种分类模型训练装置,其特征在于,包括:
构建模块,用于在待训练的分类模型中构建针对于期望预测类别的第一任务和针对于非期望预测类别的第二任务;
预测模块,用于将预先准备好的样本输入至所述分类模型,经所述第一任务预测得到各预测类别对应的第一预测概率,经所述第二任务预测得到所述各预测类别对应的第二预测概率;
第一计算模块,用于利用每个所述第一预测概率计算所述第一任务的第一损失函数值,同时利用每个所述第二预测概率计算所述第二任务的第二损失函数值,以及利用每个所述第一预测概率和每个所述第二预测概率计算所述第一任务和所述第二任务的关联损失函数值;
第二计算模块,用于基于所述第一损失函数值、所述第二损失函数值、所述关联损失函数值计算所述分类模型的最终损失函数值;
训练模块,用于根据所述最终损失函数值反向传播更新所述分类模型。
9.一种终端,其特征在于,所述终端包括处理器、与所述处理器耦接的存储器,其中,
所述存储器存储有用于实现如权利要求1-7中任一项所述的分类模型训练方法的程序指令;
所述处理器用于执行所述存储器存储的所述程序指令以基于多任务来训练分类模型。
10.一种存储介质,其特征在于,存储有能够实现如权利要求1-7中任一项所述的分类模型训练方法的程序文件。
CN202011348555.7A 2020-11-26 2020-11-26 分类模型训练方法、装置、终端及存储介质 Pending CN112465017A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202011348555.7A CN112465017A (zh) 2020-11-26 2020-11-26 分类模型训练方法、装置、终端及存储介质
PCT/CN2021/083844 WO2021208722A1 (zh) 2020-11-26 2021-03-30 分类模型训练方法、装置、终端及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011348555.7A CN112465017A (zh) 2020-11-26 2020-11-26 分类模型训练方法、装置、终端及存储介质

Publications (1)

Publication Number Publication Date
CN112465017A true CN112465017A (zh) 2021-03-09

Family

ID=74808565

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011348555.7A Pending CN112465017A (zh) 2020-11-26 2020-11-26 分类模型训练方法、装置、终端及存储介质

Country Status (2)

Country Link
CN (1) CN112465017A (zh)
WO (1) WO2021208722A1 (zh)

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113011532A (zh) * 2021-04-30 2021-06-22 平安科技(深圳)有限公司 分类模型训练方法、装置、计算设备及存储介质
CN113065614A (zh) * 2021-06-01 2021-07-02 北京百度网讯科技有限公司 分类模型的训练方法和对目标对象进行分类的方法
WO2021208722A1 (zh) * 2020-11-26 2021-10-21 平安科技(深圳)有限公司 分类模型训练方法、装置、终端及存储介质
CN113657447A (zh) * 2021-07-14 2021-11-16 南京邮电大学 一种数据融合方法、装置、设备及存储介质
CN113887679A (zh) * 2021-12-08 2022-01-04 四川大学 融合后验概率校准的模型训练方法、装置、设备及介质
CN115630289A (zh) * 2022-12-21 2023-01-20 白杨时代(北京)科技有限公司 一种基于证据理论的目标识别方法及装置

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113822382B (zh) * 2021-11-22 2022-02-15 平安科技(深圳)有限公司 基于多模态特征表示的课程分类方法、装置、设备及介质
CN114417832B (zh) * 2021-12-08 2023-05-05 马上消费金融股份有限公司 消歧方法、消歧模型的训练方法及装置
CN114066105B (zh) * 2022-01-11 2022-09-27 浙江口碑网络技术有限公司 运单配送超时预估模型的训练方法,存储介质和电子设备
CN116304811B (zh) * 2023-02-28 2024-01-16 王宇轩 一种基于焦点损失函数动态样本权重调整方法及系统
CN117056836B (zh) * 2023-10-13 2023-12-12 腾讯科技(深圳)有限公司 程序分类模型的训练、程序类目识别方法及装置
CN117579399B (zh) * 2024-01-17 2024-05-14 北京智芯微电子科技有限公司 异常流量检测模型的训练方法和系统、异常流量检测方法

Family Cites Families (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170109651A1 (en) * 2015-10-20 2017-04-20 International Business Machines Corporation Annotating text using emotive content and machine learning
CN109711427A (zh) * 2018-11-19 2019-05-03 深圳市华尊科技股份有限公司 目标检测方法及相关产品
CN110163117B (zh) * 2019-04-28 2021-03-05 浙江大学 一种基于自激励判别性特征学习的行人重识别方法
CN110826614A (zh) * 2019-10-31 2020-02-21 合肥黎曼信息科技有限公司 一种构造逆标签及其损失函数的方法
CN111695596A (zh) * 2020-04-30 2020-09-22 华为技术有限公司 一种用于图像处理的神经网络以及相关设备
CN112465017A (zh) * 2020-11-26 2021-03-09 平安科技(深圳)有限公司 分类模型训练方法、装置、终端及存储介质

Cited By (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021208722A1 (zh) * 2020-11-26 2021-10-21 平安科技(深圳)有限公司 分类模型训练方法、装置、终端及存储介质
CN113011532A (zh) * 2021-04-30 2021-06-22 平安科技(深圳)有限公司 分类模型训练方法、装置、计算设备及存储介质
CN113065614A (zh) * 2021-06-01 2021-07-02 北京百度网讯科技有限公司 分类模型的训练方法和对目标对象进行分类的方法
CN113065614B (zh) * 2021-06-01 2021-08-31 北京百度网讯科技有限公司 分类模型的训练方法和对目标对象进行分类的方法
CN113657447A (zh) * 2021-07-14 2021-11-16 南京邮电大学 一种数据融合方法、装置、设备及存储介质
CN113887679A (zh) * 2021-12-08 2022-01-04 四川大学 融合后验概率校准的模型训练方法、装置、设备及介质
CN113887679B (zh) * 2021-12-08 2022-03-08 四川大学 融合后验概率校准的模型训练方法、装置、设备及介质
CN115630289A (zh) * 2022-12-21 2023-01-20 白杨时代(北京)科技有限公司 一种基于证据理论的目标识别方法及装置
CN115630289B (zh) * 2022-12-21 2023-09-26 白杨时代(北京)科技有限公司 一种基于证据理论的目标识别方法及装置

Also Published As

Publication number Publication date
WO2021208722A1 (zh) 2021-10-21

Similar Documents

Publication Publication Date Title
CN112465017A (zh) 分类模型训练方法、装置、终端及存储介质
CN109101537B (zh) 基于深度学习的多轮对话数据分类方法、装置和电子设备
CN112613308B (zh) 用户意图识别方法、装置、终端设备及存储介质
WO2021114840A1 (zh) 基于语义分析的评分方法、装置、终端设备及存储介质
CN108846077B (zh) 问答文本的语义匹配方法、装置、介质及电子设备
CN112164391B (zh) 语句处理方法、装置、电子设备及存储介质
US8180633B2 (en) Fast semantic extraction using a neural network architecture
CN112084383A (zh) 基于知识图谱的信息推荐方法、装置、设备及存储介质
CN111602128A (zh) 计算机实现的确定方法和系统
CN111680159A (zh) 数据处理方法、装置及电子设备
CN113627447A (zh) 标签识别方法、装置、计算机设备、存储介质及程序产品
CN114548101B (zh) 基于可回溯序列生成方法的事件检测方法和系统
CN113158687B (zh) 语义的消歧方法及装置、存储介质、电子装置
Wu et al. MARMOT: A deep learning framework for constructing multimodal representations for vision-and-language tasks
CN111339775A (zh) 命名实体识别方法、装置、终端设备及存储介质
CN110399472A (zh) 面试提问提示方法、装置、计算机设备及存储介质
CN112966517A (zh) 命名实体识别模型的训练方法、装置、设备及介质
CN111695335A (zh) 一种智能面试方法、装置及终端设备
CN114707041A (zh) 消息推荐方法、装置、计算机可读介质及电子设备
CN114238656A (zh) 基于强化学习的事理图谱补全方法及其相关设备
Lauren et al. A low-dimensional vector representation for words using an extreme learning machine
CN113868451A (zh) 基于上下文级联感知的社交网络跨模态对话方法及装置
Joty et al. Modeling speech acts in asynchronous conversations: A neural-CRF approach
WO2023116572A1 (zh) 一种词句生成方法及相关设备
CN108536666A (zh) 一种短文本信息提取方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20210309

RJ01 Rejection of invention patent application after publication