CN113435525A - 分类网络训练方法、装置、计算机设备及存储介质 - Google Patents

分类网络训练方法、装置、计算机设备及存储介质 Download PDF

Info

Publication number
CN113435525A
CN113435525A CN202110745507.XA CN202110745507A CN113435525A CN 113435525 A CN113435525 A CN 113435525A CN 202110745507 A CN202110745507 A CN 202110745507A CN 113435525 A CN113435525 A CN 113435525A
Authority
CN
China
Prior art keywords
classification
training
sample set
classification network
samples
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110745507.XA
Other languages
English (en)
Other versions
CN113435525B (zh
Inventor
陈攀
刘莉红
刘玉宇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202110745507.XA priority Critical patent/CN113435525B/zh
Priority claimed from CN202110745507.XA external-priority patent/CN113435525B/zh
Publication of CN113435525A publication Critical patent/CN113435525A/zh
Application granted granted Critical
Publication of CN113435525B publication Critical patent/CN113435525B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2413Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
    • G06F18/24133Distances to prototypes
    • G06F18/24137Distances to cluster centroïds
    • G06F18/2414Smoothing the distance, e.g. radial basis function networks [RBFN]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及人工智能技术领域,尤其涉及一种分类网络训练方法、装置、计算机设备及存储介质。该分类网络训练方法包括获取训练样本集;按照批处理量大小从训练样本集中随机抽取第一样本集;将每一第一训练样本分别输入至待训练的第一分类网络中以及第二分类网络中进行分类,获取每一第一训练样本对应的第一损失值以及第二损失值;分别对第一样本集对应的多个第一损失值以及多个第二损失值进行递增顺序排序,并选取第一损失值排在前N位的第一训练样本作为第二样本集,第二损失值排在前N位的第二训练样本作为第三样本集;通过第二样本集训练第二分类网络;通过第三样本集训练第一分类网络。该方法可有效降低错误分类样本对网络训练的影响。

Description

分类网络训练方法、装置、计算机设备及存储介质
技术领域
本发明涉及人工智能技术领域,尤其涉及一种分类网络训练方法、装置、计算机设备及存储介质。
背景技术
目前,由于深度学习网络的参数空间较大,具有非常强的泛化和拟合能力,因此当训练数据集中存在一定比例的错误标注数据,很容易会被网络学习,即在错误标注数据上拟合,从而影响模型鲁棒性。
在车损伤数据中,由于车损伤形态的千变万化以及损失程度等考虑因素较多,一般需要专业的定损专家才能保证标注数据的准确性,但是借助车损定损专家进行车损等级标注的代价太高,而借助普通的经过培训的标注人员进行车损等级标注,会使得到的训练数据集中存在部分错误标注数据,仍然无法保证模型鲁棒性,因此,如何使网络在存在部分错误标注数据的训练数据集上进行训练,且可同时保证模型鲁棒性已成为目前亟待解决的问题。
发明内容
本发明实施例提供一种分类网络训练方法、装置、计算机设备及存储介质,以解决网络在存在部分错误标注数据的训练数据集上进行训练,无法保证模型鲁棒性的问题。
一种分类网络训练方法,包括:
获取训练样本集;其中,所述训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本;
按照批处理量大小从所述训练样本集中随机抽取第一样本集;其中,所述第一样本集对应多个第一训练样本;
将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值;以及,将每一所述第一训练样本输入至待训练的第二分类网络中进行分类,获取每一所述第一训练样本对应的第二损失值;
对所述第一样本集对应的多个第一损失值进行递增顺序排序,并选取所述第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对所述所述第一样本集对应的多个第二损失值进行递增顺序排序,并选取所述第二损失值排在前N位的所述第一训练样本作为第三样本集;
通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络。
一种分类网络训练装置,包括:
样本集获取模块,用于获取训练样本集;其中,所述训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本;
样本抽取模块,用于按照批处理量大小从所述训练样本集中随机抽取第一样本集;其中,所述第一样本集对应多个第一训练样本;
损失输出模块,用于将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值;以及,将每一所述第一训练样本输入至待训练的第二分类网络中进行分类,获取每一所述第一训练样本对应的第二损失值;
样本过滤模块,用于对所述第一样本集对应的多个第一损失值进行递增顺序排序,并选取所述第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对所述所述第一样本集对应的多个第二损失值进行递增顺序排序,并选取所述第二损失值排在前N位的所述第一训练样本作为第三样本集;
联合训练模块,用于通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络。
一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述分类网络训练方法的步骤。
一种计算机存储介质,所述计算机存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述分类网络训练方法的步骤。
上述分类网络训练方法、装置、计算机设备及存储介质中,通过获取训练样本集,该训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本,以使后续的模型训练可在该带有错误分类样本的训练样本集上进行训练,提升模型的鲁棒性和准确性。然后,按照批处理量大小从所述训练样本集中随机抽取第一样本集,以批量训练不同的样本集,通过将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值;以及,将所述第一样本集输入至待训练的第二分类网络中进行分类,获取所述第一样本集对应的多个第二损失值,以根据该第一损失值和第二损失值对第一样本集进行过滤,得到低损失的第二样本集以及第三样本集,最后,通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络,以使第一分类网络以及第二分类网络互相学习错误过滤能力,通过不同分类网络间的低损失样本的交换,实现分类网络的联合交叉训练,从而有效降低错误分类样本对网络训练的影响。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明一实施例中分类网络训练方法的一应用环境示意图;
图2是本发明一实施例中分类网络训练方法的一流程图;
图3是本发明一实施例中多轮联合训练的框架图;
图4是本发明一实施例中分类网络训练方法的一流程图;
图5是图2中步骤S306的一具体流程图;
图6是本发明一实施例中分类网络训练方法的一流程图;
图7是图2中步骤S203的一具体流程图;
图8是本本发明一实施例中分类网络中的特征提取层的一具体结构图;
图9是图2中步骤S201的一具体流程图;
图10是本发明一实施例中分类网络训练装置的一示意图;
图11是本发明一实施例中计算机设备的一示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
该分类网络训练方法可应用在如图1的应用环境中,其中,计算机设备通过网络与服务器进行通信。计算机设备可以但不限于各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备。服务器可以用独立的服务器来实现。
在一实施例中,如图2所示,提供一种分类网络训练方法,以该方法应用在图1中的服务器为例进行说明,包括如下步骤:
S201:获取训练样本集,其中,训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本。
其中,本申请中的分类网络训练方法可适用于不同分类任务的应用场景,例如车损分类场景。以下说明以车损分类的应用场景为例进行说明。该训练样本集中包括预先对车损图像中车辆损伤的形态按照严重程度进行分类,得到不同严重级别的分类标签。该车损图像以及对应的分类标签即可作为训练样本进行训练。
可以理解地是,由于由于车损形态是一个连续性的变化形态,需要考虑位置、深度、面积等诸多因素,因此在实际标注时,会出现一定比例的数据标签标记错误的情况,因此若要保证训练样本集中完全正确的标注,则需要更高级的定损人员和多轮次的质检,大大增加标注成本。故本实施例中,通过针对不同分类标签对应的正确分类样本进行错误标记(即错误分类样本),以使模型训练时可在该带有错误分类样本的训练样本集上进行训练,提升模型的鲁棒性和准确性。
S202:按照批处理量大小从训练样本集中随机抽取第一样本集,其中,第一样本集对应多个第一训练样本。
其中,批处理量大小(batch size)即指训练过程中不同批次训练中所需的样本量,即分批训练时每一轮训练所需的样本量(例如500)。该批处理量大小例可通过预先设置。具体地,按照该批处理量大小从所述训练样本集中随机抽取多个第一训练样本即得到第一样本集。
S203:将每一第一训练样本输入至待训练的第一分类网络中进行分类,获取每一第一训练样本对应的第一损失值;以及,将每一第一训练样本输入至待训练的第二分类网络中进行分类,获取每一第一训练样本对应的第二损失值。
具体地,该所述第一分类网络与第二分类网络的模型初始化参数不同,模型结构相同;所述第一分类网络以及所述第二分类网络均包括特征提取层以及分类层;如图7所示,所述特征提取层包括多个残差模块,每两个相邻的所述残差模块之间通过注意力机制模块连接,以使模型训练时中重点关注车损图像的全局特征信息。
其中,特征提取层即通过ResNet50实现,该ResNet50由多个残差模块构成,每一残差模块通过多层卷积层实现。需要说明的是,本实施例中的ResNet50特征骨干网络与传统的ResNet50特征骨干网络不同,即在传统的ResNet50特征骨干网络中引入了注意力机制模块,即每两个相邻的所述残差模块之间通过注意力机制模块连接,以以使模型训练时中重点关注车损图像的全局特征信息。
可以理解地是,通过将第一样本集中的每一第一训练样本输入至第一分类网络中进行分类,即可得到第一分类网络输出的该第一训练样本分类结果,基于分类结果与第一训练样本对应的真实分类标签计算对应的交叉熵损失或均方误差损失,即可得到第一损失值。针对本实施例中的应用场景为分类任务,故可采用交叉熵损失损失计算函数基于分类结果与第一训练样本对应的真实分类标签计算对应的第一损失值。其中,交叉熵损失计算函数包括
Figure BDA0003142578960000051
其中,y表示真实分类标签,
Figure BDA0003142578960000052
表示分类结果,n表示批处理量大小。
S204:对第一样本集对应的多个第一损失值进行递增顺序排序,并选取第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对第一样本集对应的多个第二损失值进行递增顺序排序,并选取第二损失值排在前N位的第一训练样本作为第三样本集。
其中,N为第二样本集以及第三样本集在第一样本集中所占比例,可通过预设设置或通过对所述第一样本集进行随机抽样,并根据得到的随机抽样集中正样本的占比估计该N值。
具体地,通过对多个第一损失值按照从小到大进行排序,将第一损失值排在前N位的第一训练样本作为第二样本集,以及对多个第二损失值按照从小到大进行排序,将第二损失值排在前N位的第一训练样本作为第三样本集,以将损失较大的样本过滤,保留损失较小的样本,以交叉训练分类网络。
S205:通过第二样本集训练第二分类网络;以及,通过第三样本集训练第一分类网络。
具体地,由于不同初始化参数的分类网络对数据的学习能力不同,当在包含错误分类样本的训练样本集上训练时,不同的分类网络对于错误分类样本有不同的过滤能力。故本实施例中,通过该第二样本集,即根据第一分类网络输出的第一损失值确定的低损失的第二样本集训练第二分类网络,以使第二分类网络学习第一分类网络的错误过滤能力;以及,通过该第三样本集,即根据第二分类网络输出的第二损失值确定的低损失的第三样本集训练第一分类网络,以使第一分类网络学习第二分类网络的错误过滤能力,实现不同分类网络间的低损失数据的交换,实现分类网络的联合交叉训练,从而有效降低错误分类样本对网络训练的影响。
如图3所示的多轮联合训练框架图,本实施例中的分类网络训练方法通过联合交叉训练的思想,即通过抽取第一分类网络A中的低损失的训练样本集(即第一样本集)训练第二分类网络B,然后通过抽取第二分类网络B中的低损失的训练样本集(即第二样本集)训练第一分类网络A,如此循环进行交叉训练,直至模型收敛,即可得到训练好的第一分类网络A以及第二分类网络B。
进一步地,在得到训练好的第一分类网络以及第二分类网络后,可选择任一网络对实际应用中车损图像进行分类,得到分类结果,或者将两个分类网络的分类结果做加权处理,得到整合后的分类结果,此处不做限定。
本实施例中,通过获取训练样本集,该训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本,以使后续的模型训练可在该带有错误分类样本的训练样本集上进行训练,提升模型的鲁棒性和准确性。然后,按照批处理量大小从所述训练样本集中随机抽取第一样本集,以批量训练不同的样本集,通过将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值;以及,将所述第一样本集输入至待训练的第二分类网络中进行分类,获取所述第一样本集对应的多个第二损失值,以根据该第一损失和第二损失对第一样本集进行过滤,得到低损失的第二样本集以及第三样本集,最后,通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络,以使第一分类网络以及第二分类网络互相学习错误过滤能力,通过不同分类网络间的低损失样本的交换,实现分类网络的联合交叉训练,从而有效降低错误分类样本对网络训练的影响。
在一实施例中,如图4所示,该分类网络训练方法还包括如下步骤:
S301:获取训练样本集。
其中,步骤S301与步骤S201的实现过程一致,为避免重复,此处不再赘述。
S302:按照批处理量大小从所述训练样本集中随机抽取第一样本集。
其中,步骤S302与步骤S202的实现过程一致,为避免重复,此处不再赘述。
S303:将每一第一训练样本输入至待训练的第一分类网络中进行分类,获取每一第一训练样本对应的第一损失值;以及,将每一第一训练样本输入至待训练的第二分类网络中进行分类,获取每一第一训练样本对应的第二损失值。
其中,步骤S303与步骤S203的实现过程一致,为避免重复,此处不再赘述。
S304:对第一样本集对应的多个第一损失值进行递增顺序排序,并选取第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对第一样本集对应的多个第二损失值进行递增顺序排序,并选取第二损失值排在前N位的第一训练样本作为第三样本集。
其中,步骤S304与步骤S204的实现过程一致,为避免重复,此处不再赘述。
S305:通过第二样本集训练第二分类网络;以及,通过第三样本集训练第一分类网络。
其中,步骤S305与步骤S205的实现过程一致,为避免重复,此处不再赘述。
S306:更新N值。
具体地,同一轮训练中该第二样本集以及第三样本集所占第一样本集的比例相同,即N值相同。不同轮训练中的N值可相同或不同,每轮训练中N值的确定可根据先验知识确定预设比例,以根据该预设比例更新每轮训练中的N值;或者,通过对所述第一样本集进行随机抽样,并根据得到的随机抽样集中正样本的占比估计所述N值。
S307:重复执行步骤S302-S306,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
具体地,通过重复执行S302-S306,直至模型收敛,即可获取训练好的第一分类网络以及第二分类网络。
在一实施例中,如图5所示,步骤S304中,步骤S306中,即更新损失抽样比例,具体包括如下步骤:
S401:按照预设比例更新N值;或者,
其中,预设比例为根据先验知识即经验值设定,根据该经验值设定每轮训练的N值。
S402:对第一样本集进行随机抽样,并根据得到的随机抽样集中正样本的占比估计N值。
具体地,通过对第一样本集进行随机抽样,得到一随机抽样集,然后将该随机抽样集中的正样本所占比例作为N值,或者在此基础上进一步减小N值,此处不做限定。
在一实施例中,如图6所示,该分类网络训练方法还包括如下步骤:
S501:获取训练样本集。
其中,步骤S501与步骤S201的实现过程一致,为避免重复,此处不再赘述。
S502:按照批处理量大小从训练样本集中随机抽取第一样本集。
其中,步骤S502与步骤S202的实现过程一致,为避免重复,此处不再赘述。
S503:将每一第一训练样本输入至待训练的第一分类网络中进行分类,获取每一第一训练样本对应的第一损失值;以及,将每一第一训练样本输入至待训练的第二分类网络中进行分类,获取每一第一训练样本对应的第二损失值。
其中,步骤S503与步骤S203的实现过程一致,为避免重复,此处不再赘述。
S504:对第一样本集对应的多个第一损失值进行递增顺序排序,并选取第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对第一样本集对应的多个第二损失值进行递增顺序排序,并选取第二损失值排在前N位的第一训练样本作为第三样本集。
其中,步骤S504与步骤S204的实现过程一致,为避免重复,此处不再赘述。
S505:通过第二样本集训练第二分类网络;以及,通过第三样本集训练第一分类网络。
其中,步骤S505与步骤S205的实现过程一致,为避免重复,此处不再赘述。
S506:更新批处理量大小;其中,更新后的批处理量大小小于更新前的批处理量大小。
可以理解地是,当网络准确时,损失较小的训练样本可认为均是正确分类样本,当训练样本集中存在错误分类文本时,深度学习网络在最初迭代过程中会首先学习正确的、简单的模式,因此,在深度学习网络的初始训练阶段,可通过损失来区分正确分类样本或者错误分类样本;而随着网络学习的深入,网络会逐渐在错误分类样本上拟合。为解决上述问题,本实施例中在网络的初始训练阶段设置一较大的batch size(即批处理量大小),然后在此基础上,通过逐步减小batch size,以在网络拟合错误分类样本之前将错误分类样本最大程度过滤掉。
具体地,通过在每轮训练中,将上一轮训练的批处理量大小减小一单位步长,即可获取本轮训练的批处理大小,实现对批处理量大小的更新。
S507:重复执行S502-S506,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
具体地,通过重复执行S502-S506,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
在一实施例中,如图7所示,步骤S303中,获取第一分类网络输出的第一样本集对应的多个第一损失值,包括:
S601:将每一第一训练样本输入至残差模块进行特征提取,得到残差模块的输出。
其中,残差模块由多个卷积层构成,具体结构与ResNet50中的每一残差模块的具体结构相同,此处不再详述。
具体地,将所述第一训练样本输入至所述残差模块进行特征提取,即经过一系列的卷积操作,即可得到残差模块的输出。
S602:将残差模块的输出输入至注意力机制模块中进行处理,得到注意力机制模块的输出。
其中,该注意力机制模块包括卷积层、分类层、归一化层以及激活层,具体结构如图8所示。该注意力机制模块中的处理过程包括:通过接收注意力机制模块的原始输入A,然后将原始输入A输入至1*1的卷积层进行卷积操作,得到卷积输出B,并将该卷积输出B输入至分类层(softmax)进行分类,得到分类输出C,将分类输出C与原始输入A拼接,并将拼接后的结果输入至1*1的卷积层卷积操作,得到卷积输出D,将该卷积输出D输入至归一化层进行标准化处理,得到标准化输出F,然后将该标准化输出F输入至激活层(ReLU)进行激活处理,得到激活输出G,然后将该激活输出G与原始输入A进行拼接,并将拼接后的结果作为注意力机制模块的输出。
可以理解地是,该注意力机制模块的输入可指上一残差模块基于第一训练样本进行特征提取得到或者上一残差模块基于上一注意力机制模块的输出进行特征提取得到。
S603:将注意力机制模块的输出输入至下一残差模块,以使下一残差模块对注意力机制模块的输出进行特征提取,得到残差模块的输出。
S604:重复执行步骤S602-S603,直至得到最后一个残差模块的输出。
具体地,通过将注意力机制模块的输出输入至下一残差模块,以使下一残差模块对注意力机制模块的输出进行特征提取,得到残差模块的输出,并重复执行步骤S602-S603,直至得到最后一个残差模块的输出。
S605:将最后一个残差模块的输出输入至分类器进行分类,得到预测分类结果。
具体地,将该最后一个残差模块的输出输入至所述分类器进行分类,即可得到预测分类结果。该分类器可采用目前开源的训练好的分类器,通过将该最后一个残差模块的输出,即特征提取层输出的图像特征输入至分类器,即可得到预测分类结果。
S606:基于预测分类结果与第一训练样本对应的真实分类标签,以得到每一第一训练样本对应的第一损失值。
具体地,通过采用损失函数基于预测分类结果与所述第一训练样本对应的真实分类标签,以计算每一所述第一训练样本对应的第一损失值,从而得到第一样本集对应的多个第一损失值。其中,损失函数包括但不限于交叉熵损失函数以及均方误差损失函数。
进一步地,由于第一分类网络与第二分类网络的模型结构相同,故针对步骤S303中,获取第二分类网络输出的所述第一样本集对应的多个第二损失值的具体实现过程与步骤S601-S606类似,为避免重复,此处不再赘述。
在一实施例中,如图9所示,步骤S201中,即获取训练样本集,具体包括如下步骤:
S701:获取每一分类标签对应的多个正确分类样本。
S702:按照预设错误标记比例对多个正确标记样本的分类标签进行错误标记,得到每一分类标签对应的多个错误分类样本。
其中,本实施例中的分类标签可根据车损严重程度进行划分,例如包括严重、中等、轻微等。通过获取预先根据该分类标签标注的正确分类样本,然后根据预设错误标记比例,随机对多个所述正确标记样本的分类标签进行错误标记,得到每一所述分类标签对应的多个错误分类样本,即针对不同的分类标签其对应的负样本比例相同,以保证负样本比例在训练样本集中是均匀分布的,保证样本均衡。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
在一实施例中,提供一种分类网络训练装置,该分类网络训练装置与上述实施例中分类网络训练方法一一对应。如图10所示,该分类网络训练装置包括样本集获取模块10、样本抽取模块20、损失输出模块30、样本过滤模块40以及联合训练模块50。各功能模块详细说明如下:
样本集获取模块10,用于获取训练样本集;其中,所述训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本;
样本抽取模块20,用于按照批处理量大小从所述训练样本集中随机抽取第一样本集;其中,所述第一样本集对应多个第一训练样本;
损失输出模块30,用于将每一第一训练样本输入至待训练的第一分类网络中进行分类,获取每一第一训练样本对应的第一损失值;以及,将每一第一训练样本输入至待训练的第二分类网络中进行分类,获取每一第一训练样本对应的第二损失值;
样本过滤模块40,用于对第一样本集对应的多个第一损失值进行递增顺序排序,并选取第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对第一样本集对应的多个第二损失值进行递增顺序排序,并选取第二损失值排在前N位的第一训练样本作为第三样本集;
联合训练模块50,用于通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络。
具体地,该分类网络训练装置还包括第一更新模块以及迭代训练模块。
第一更新模块,用于更新所述N值;
迭代训练模块,用于重复执行所述按照批处理量大小从所述训练样本集中随机抽取第一样本集的步骤,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
具体地,该第一更新模块包括第一更新单元以及第二更新单元。
第一更新单元,用于按照预设比例更新所述N值;或者,
第二更新单元,用于通过对所述第一样本集进行随机抽样,并根据得到的随机抽样集中正样本的占比估计所述损失抽样比例。
具体地,该分类网络训练装置还包括第二更新模块以及迭代训练模块。
第二更新模块,用于更新所述批处理量大小;其中,所述更新后的批处理量大小小于更新前的批处理量大小;
迭代训练模块,用于重复执行所述按照批处理量大小从所述训练样本集中随机抽取第一样本集的步骤,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
具体地,所述第一分类网络与第二分类网络的模型初始化参数不同,模型结构相同;所述第一分类网络以及所述第二分类网络均包括特征提取层以及分类层;所述特征提取层包括多个残差模块,每两个相邻的所述残差模块之间通过注意力机制模块连接。
具体地,该损失输出模块包括残差处理单元、注意力处理单元、级联处理单元、循环处理单元、分类单元以及损失计算单元。
残差处理单元,用于将每一所述第一训练样本输入至所述残差模块进行特征提取,得到残差模块的输出;
注意力处理单元,用于将所述残差模块的输出输入至所述注意力机制模块中进行处理,得到所述注意力机制模块的输出;
级联处理单元,用于将所述注意力机制模块的输出输入至下一残差模块,以使所述下一残差模块对所述注意力机制模块的输出进行特征提取,得到所述残差模块的输出;
循环处理单元,用于重复执行所述所述将所述残差模块的输出输入至所述注意力机制模块中进行处理,得到所述注意力机制模块的输出的步骤,直至得到最后一个残差模块的输出;
分类单元,用于将所述最后一个残差模块的输出输入至所述分类器进行分类,得到预测分类结果;
损失计算单元,用于基于所述预测分类结果与所述第一训练样本对应的真实分类标签,以得到每一所述第一训练样本对应的第一损失值。
具体地,样本集获取模块包括正确样本获取单元以及错误样本获取单元。
正确样本获取单元,用于获取每一所述分类标签对应的多个正确分类样本。
错误样本获取单元,用于按照预设错误标记比例对多个所述正确标记样本的分类标签进行错误标记,得到每一所述分类标签对应的多个错误分类样本。
关于分类网络训练装置的具体限定可以参见上文中对于分类网络训练方法的限定,在此不再赘述。上述分类网络训练装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图11所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括计算机存储介质、内存储器。该计算机存储介质存储有操作系统、计算机程序和数据库。该内存储器为计算机存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储执行分类网络训练方法过程中生成或获取的数据,如第一分类网络。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种分类网络训练方法。
在一个实施例中,提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述实施例中的分类网络训练方法的步骤,例如图2所示的步骤S201-S205,或者图4至图7、图9中所示的步骤。或者,处理器执行计算机程序时实现分类网络训练装置这一实施例中的各模块/单元的功能,例如图10所示的各模块/单元的功能,为避免重复,这里不再赘述。
在一实施例中,提供一计算机存储介质,该计算机存储介质上存储有计算机程序,该计算机程序被处理器执行时实现上述实施例中分类网络训练方法的步骤,例如图2所示的步骤S201-S205,或者图4至图7、图9中所示的步骤,为避免重复,这里不再赘述。或者,该计算机程序被处理器执行时实现上述分类网络训练装置这一实施例中的各模块/单元的功能,例如图10所示的各模块/单元的功能,为避免重复,这里不再赘述。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。

Claims (10)

1.一种分类网络训练方法,其特征在于,包括:
获取训练样本集;其中,所述训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本;
按照批处理量大小从所述训练样本集中随机抽取第一样本集;其中,所述第一样本集对应多个第一训练样本;
将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值;以及,将每一所述第一训练样本输入至待训练的第二分类网络中进行分类,获取每一所述第一训练样本对应的第二损失值;
对所述第一样本集对应的多个第一损失值进行递增顺序排序,并选取所述第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对所述所述第一样本集对应的多个第二损失值进行递增顺序排序,并选取所述第二损失值排在前N位的所述第一训练样本作为第三样本集;
通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络。
2.如权利要求1所述分类网络训练方法,其特征在于,在所述通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络之后,所述分类网络训练方法还包括:
更新所述N值;
重复执行所述按照批处理量大小从所述训练样本集中随机抽取第一样本集的步骤,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
3.如权利要求2所述分类网络训练方法,其特征在于,所述更新N值,包括:
按照预设比例更新所述N值;或者,
对所述第一样本集进行随机抽样,并根据得到的随机抽样集中正样本的占比估计所述损失抽样比例。
4.如权利要求1所述分类网络训练方法,其特征在于,在所述通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络之后,所述分类网络训练方法还包括:
更新所述批处理量大小;其中,所述更新后的批处理量大小小于更新前的批处理量大小;
重复执行所述按照批处理量大小从所述训练样本集中随机抽取第一样本集的步骤,直至模型收敛,获取训练好的第一分类网络以及第二分类网络。
5.如权利要求1所述分类网络训练方法,其特征在于,所述第一分类网络与第二分类网络的模型初始化参数不同,模型结构相同;所述第一分类网络以及所述第二分类网络均包括特征提取层以及分类层;所述特征提取层包括多个残差模块,每两个相邻的所述残差模块之间通过注意力机制模块连接。
6.如权利要求5所述分类网络训练方法,其特征在于,所述将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值,包括:
将每一所述第一训练样本输入至所述残差模块进行特征提取,得到残差模块的输出;
将所述残差模块的输出输入至所述注意力机制模块中进行处理,得到所述注意力机制模块的输出;
将所述注意力机制模块的输出输入至下一残差模块,以使所述下一残差模块对所述注意力机制模块的输出进行特征提取,得到所述残差模块的输出;
重复执行所述所述将所述残差模块的输出输入至所述注意力机制模块中进行处理,得到所述注意力机制模块的输出的步骤,直至得到最后一个残差模块的输出;
将所述最后一个残差模块的输出输入至所述分类器进行分类,得到预测分类结果;
基于所述预测分类结果与所述第一训练样本对应的真实分类标签,以得到每一所述第一训练样本对应的第一损失值。
7.如权利要求1所述分类网络训练方法,其特征在于,所述获取训练样本集,包括:
获取每一所述分类标签对应的多个正确分类样本;
按照预设错误标记比例对多个所述正确标记样本的分类标签进行错误标记,得到每一所述分类标签对应的多个错误分类样本。
8.一种分类网络训练装置,其特征在于,包括:
样本集获取模块,用于获取训练样本集;其中,所述训练样本集包括不同分类标签对应的正确分类样本以及错误分类样本;
样本抽取模块,用于按照批处理量大小从所述训练样本集中随机抽取第一样本集;其中,所述第一样本集对应多个第一训练样本;
损失输出模块,用于将每一所述第一训练样本输入至待训练的第一分类网络中进行分类,获取所述每一所述第一训练样本对应的第一损失值;以及,将每一所述第一训练样本输入至待训练的第二分类网络中进行分类,获取每一所述第一训练样本对应的第二损失值;
样本过滤模块,用于对所述第一样本集对应的多个第一损失值进行递增顺序排序,并选取所述第一损失值排在前N位的所述第一训练样本作为第二样本集;以及,对所述所述第一样本集对应的多个第二损失值进行递增顺序排序,并选取所述第二损失值排在前N位的所述第一训练样本作为第三样本集;
联合训练模块,用于通过所述第二样本集训练所述第二分类网络;以及,通过所述第三样本集训练所述第一分类网络。
9.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述分类网络训练方法的步骤。
10.一种计算机存储介质,所述计算机存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述分类网络训练方法的步骤。
CN202110745507.XA 2021-06-30 分类网络训练方法、装置、计算机设备及存储介质 Active CN113435525B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110745507.XA CN113435525B (zh) 2021-06-30 分类网络训练方法、装置、计算机设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110745507.XA CN113435525B (zh) 2021-06-30 分类网络训练方法、装置、计算机设备及存储介质

Publications (2)

Publication Number Publication Date
CN113435525A true CN113435525A (zh) 2021-09-24
CN113435525B CN113435525B (zh) 2024-06-21

Family

ID=

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108805185A (zh) * 2018-05-29 2018-11-13 腾讯科技(深圳)有限公司 模型的训练方法、装置、存储介质及计算机设备
CN111046959A (zh) * 2019-12-12 2020-04-21 上海眼控科技股份有限公司 模型训练方法、装置、设备和存储介质
CN111523596A (zh) * 2020-04-23 2020-08-11 北京百度网讯科技有限公司 目标识别模型训练方法、装置、设备以及存储介质
CN111860669A (zh) * 2020-07-27 2020-10-30 平安科技(深圳)有限公司 Ocr识别模型的训练方法、装置和计算机设备
CN112990432A (zh) * 2021-03-04 2021-06-18 北京金山云网络技术有限公司 目标识别模型训练方法、装置及电子设备

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108805185A (zh) * 2018-05-29 2018-11-13 腾讯科技(深圳)有限公司 模型的训练方法、装置、存储介质及计算机设备
CN111046959A (zh) * 2019-12-12 2020-04-21 上海眼控科技股份有限公司 模型训练方法、装置、设备和存储介质
CN111523596A (zh) * 2020-04-23 2020-08-11 北京百度网讯科技有限公司 目标识别模型训练方法、装置、设备以及存储介质
CN111860669A (zh) * 2020-07-27 2020-10-30 平安科技(深圳)有限公司 Ocr识别模型的训练方法、装置和计算机设备
CN112990432A (zh) * 2021-03-04 2021-06-18 北京金山云网络技术有限公司 目标识别模型训练方法、装置及电子设备

Similar Documents

Publication Publication Date Title
CN109241903B (zh) 样本数据清洗方法、装置、计算机设备及存储介质
CN109189767B (zh) 数据处理方法、装置、电子设备及存储介质
CN109086654B (zh) 手写模型训练方法、文本识别方法、装置、设备及介质
CN109740689B (zh) 一种图像语义分割的错误标注数据筛选方法及系统
CN111368874A (zh) 一种基于单分类技术的图像类别增量学习方法
CN113785305A (zh) 一种检测倾斜文字的方法、装置及设备
US20210390370A1 (en) Data processing method and apparatus, storage medium and electronic device
EP3620982B1 (en) Sample processing method and device
CN112862093A (zh) 一种图神经网络训练方法及装置
CN110909868A (zh) 基于图神经网络模型的节点表示方法和装置
CN110110845B (zh) 一种基于并行多级宽度神经网络的学习方法
CN113128536A (zh) 无监督学习方法、系统、计算机设备及可读存储介质
CN115810135A (zh) 样本分析的方法、电子设备、存储介质和程序产品
Li et al. Locality linear fitting one-class SVM with low-rank constraints for outlier detection
EP4343616A1 (en) Image classification method, model training method, device, storage medium, and computer program
CN111104831A (zh) 一种视觉追踪方法、装置、计算机设备以及介质
CN114626524A (zh) 目标业务网络确定方法、业务处理方法及装置
CN113283388A (zh) 活体人脸检测模型的训练方法、装置、设备及存储介质
CN109101984B (zh) 一种基于卷积神经网络的图像识别方法及装置
CN113971741A (zh) 一种图像标注方法、分类模型的训练方法、计算机设备
CN111507396A (zh) 缓解神经网络对未知类样本产生错误分类的方法及装置
CN108345943B (zh) 一种基于嵌入编码与对比学习的机器学习识别方法
CN113435525A (zh) 分类网络训练方法、装置、计算机设备及存储介质
CN113435525B (zh) 分类网络训练方法、装置、计算机设备及存储介质
CN115017819A (zh) 一种基于混合模型的发动机剩余使用寿命预测方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant