CN117787430A - 用于训练机器学习模型的方法、系统、设备、介质和程序 - Google Patents
用于训练机器学习模型的方法、系统、设备、介质和程序 Download PDFInfo
- Publication number
- CN117787430A CN117787430A CN202211151281.1A CN202211151281A CN117787430A CN 117787430 A CN117787430 A CN 117787430A CN 202211151281 A CN202211151281 A CN 202211151281A CN 117787430 A CN117787430 A CN 117787430A
- Authority
- CN
- China
- Prior art keywords
- training
- samples
- classification
- neural network
- marked
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 230
- 238000000034 method Methods 0.000 title claims abstract description 52
- 238000010801 machine learning Methods 0.000 title claims abstract description 39
- 238000012216 screening Methods 0.000 claims abstract description 51
- 239000003550 marker Substances 0.000 claims abstract description 32
- 238000013145 classification model Methods 0.000 claims abstract description 28
- 238000005070 sampling Methods 0.000 claims abstract description 23
- 238000013528 artificial neural network Methods 0.000 claims description 87
- 230000015654 memory Effects 0.000 claims description 33
- 230000006870 function Effects 0.000 claims description 24
- 238000012937 correction Methods 0.000 claims description 8
- 238000012935 Averaging Methods 0.000 claims description 3
- 238000004590 computer program Methods 0.000 claims description 3
- 238000010586 diagram Methods 0.000 description 11
- 230000008569 process Effects 0.000 description 8
- 230000009471 action Effects 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 230000007246 mechanism Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000012946 outsourcing Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 230000001360 synchronised effect Effects 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 230000008878 coupling Effects 0.000 description 1
- 238000010168 coupling process Methods 0.000 description 1
- 238000005859 coupling reaction Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000013518 transcription Methods 0.000 description 1
- 230000035897 transcription Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Landscapes
- Image Analysis (AREA)
Abstract
本申请特别涉及用于训练机器学习模型的方法、系统、设备、介质和程序。该方法包括:对标记样本进行类别均衡采样,以获得标记训练样本集合;在多个训练迭代的每个训练迭代期间:使用级联的筛选模型对标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合;使用干净标记训练样本集合来训练分类模型;使用分类模型对错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;将干净标记训练样本集合和修正后的错误标记训练样本集合组合为新的标记训练样本集合,并且返回到级联的筛选模型以进行筛选。如此,本申请实现类别均衡采样,可以更好地剔除噪声标记样本,避免过拟合,并且训练得到的机器学习模型的鲁棒性高。
Description
技术领域
本申请特别涉及用于训练机器学习模型的方法、系统、设备、介质和程序。
背景技术
训练深度神经网络通常需要大规模的标记数据。然而,大规模的标记数据通常是以廉价和不可信的方式(例如,互联网外包)获得。这些标记数据中会存在大量的噪声标记数据。噪声标记数据本质上是对于真值标签的破坏,不可避免地会对机器学习模型的训练产生干扰。特别是对于大模型容量的深度学习模型,它们在训练中会强行拟合标记数据,从而记住这些噪声标记数据,因而这对于训练深度神经网络来说是一个巨大挑战。
目前,存在联合训练两个同样类型的分类网络来剔除噪声标记数据的方案,这两个分类网络使用交叉熵作为损失函数,这对于噪声标记数据的损失会产生较大的梯度,容易产生过拟合,鲁棒性不够。并且,在噪声标记数据较多的情况下,随着训练进程的深入,两个网络可能都趋于同一个参数,从而与训练一个网络没有区别,失去了对于噪声标记数据的判别能力。另外,该方案忽略了标记数据中类别不均衡的现象,将所有标记数据的交叉熵损失一起进行比较,很容易将那些只有少量数据的类别总是错误当作噪声标记数据,从而对于那些类别缺乏学习,导致总体精度下降。
因此,亟需设计一种用于训练机器学习模型的方法,剔除噪声标记数据,避免过拟合,同时注意类别均衡采样,从而使得训练效果更好且更稳定。
发明内容
有鉴于此,本申请实施例提供用于训练机器学习模型的方法、系统、设备、介质和程序,实现类别均衡采样,解决了由于样本丢弃造成的训练类别不公平的问题,可以更好地剔除噪声标记样本,避免过拟合,并且训练得到的机器学习模型的鲁棒性高。
第一方面,本申请实施例提供了一种用于训练机器学习模型的方法,所述方法用于电子设备,并且所述方法包括:
对标记样本进行类别均衡采样,以获得标记训练样本集合;
在多个训练迭代的每个训练迭代期间:
使用级联的筛选模型对所述标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合;
使用所述干净标记训练样本集合来训练分类模型;
使用所述分类模型对所述错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;
将所述干净标记训练样本集合和修正后的所述错误标记训练样本集合组合为新的所述标记训练样本集合,并且返回到所述级联的筛选模型以进行筛选。
本申请的实施例对标记样本进行类别均衡采样,以获得标记训练样本集合,该标记训练样本集合中的所有标记训练样本中的类别比例保持均衡,解决了由于样本丢弃造成的训练类别不公平的问题。
相比于联合训练,本申请的实施例使用级联的筛选模型对标记训练样本集合进行逐层筛选,可以选择多个网络传递训练,结合集成学习的思想进一步提升样本筛选能力,从而更好地剔除噪声标记样本。
本申请的实施例利用上一阶段筛选出的干净样本进行分类网络训练,对于噪声样本进行伪标签修正,进一步将所有样本送入上一阶段再进行样本的筛选以及样本修正和分类训练,如此反复,最终得到一个鲁棒性的分类模型。
在上述第一方面的一种可能实现中,对标记样本进行类别均衡采样,以获得标记训练样本集合进一步包括:
基于类别对所述标记样本进行分类,以获得多个分类标记样本集合;
从每个分类标记样本集合中选择第一数量的分类标记样本,以获得多个分类标记样本子集合;
从所述多个分类标记样本子集合中选择第二数量的分类标记样本子集合并组合,以获得所述标记训练样本集合。
在上述第一方面的一种可能实现中,所述级联的筛选模型包括第一深度神经网络、第二深度神经网络和第三深度神经网络,并且其中,使用级联的筛选模型对所述标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合进一步包括:
分别将所述标记训练样本集合输入所述第一深度神经网络和所述第二深度神经网络;
使用所述第一深度神经网络输出所述标记训练样本集合中的第一部分标记训练样本到所述第二深度神经网络,以及使用所述第二深度神经网络输出所述标记训练样本集合中的第二部分标记训练样本到所述第一深度神经网络,用于分别更新所述第一深度神经网络和所述第二深度神经网络的反向传播参数;
使用所述第二深度神经网络输出所述标记训练样本集合中的所述第二部分标记训练样本到所述第三深度神经网络,用于训练所述第三深度神经网络。
在上述第一方面的一种可能实现中,对于所述标记训练样本集合中的每个分类标记样本子集合,使用所述第一深度神经网络计算所述分类标记样本子集合中的每个分类标记样本的logits函数值,并且将logits函数值大于第一阈值的多个分类标记样本组合为所述第一部分标记训练样本,以及使用所述第二深度神经网络计算所述分类标记样本子集合中的每个分类标记样本的logits函数值,并且将logits函数值大于第二阈值的多个分类标记样本组合为所述第二部分标记训练样本。
本申请的实施例中,第一深度神经网络和第二深度神经网络将一个迷你批的标记训练样本集合中的每一类别中的logits函数值较大的样本作为干净的样本输出到同伴网络进行反向传播参数更新。
在上述第一方面的一种可能实现中,所述第一深度神经网络、所述第二深度神经网络和所述第三深度神经网络均使用平均绝对误差作为损失函数。
本申请的实施例利用呈V字型的平均绝对误差损失曲线,即,在各个损失值的梯度是一样的,使得模型不会对于损失较大的样本产生偏置,从而使得筛选网络对于噪声标记样本的鲁棒性更强,解决了神经网络在噪声标记样本学习中容易过拟合的问题。
在上述第一方面的一种可能实现中,所述分类模型包括第四深度神经网络,并且所述第四深度神经网络使用交叉熵作为损失函数。
在上述第一方面的一种可能实现中,所述第一深度神经网络、所述第二深度神经网络和所述第三深度神经网络均采用ResNet-18网络结构并且所述第四深度神经网络采用ResNet-50网络结构,或者所述第一深度神经网络、所述第二深度神经网络和所述第三深度神经网络均采用ResNet-50网络结构并且所述第四深度神经网络采用ResNet-101网络结构。
本申请的实施例将小网络结构应用到样本筛选网络并且将大网络结构应用到样本分类网络,从而实现配对。
在上述第一方面的一种可能实现中,使用所述干净标记训练样本集合来训练多个分类模型,并且其中,所述方法进一步包括:
分别使用训练后的所述多个分类模型进行分类预测;
对训练后的所述多个分类模型的分类预测结果进行平均,作为最终的分类预测结果。
本申请的实施例在测试过程中,使用集成学习的策略,设计两个大网络分类模型进行分类预测,取两个大网络的平均值作为最终的预测结果。
第二方面,本申请实施例提供了一种用于训练机器学习模型的系统,所述系统包括:
采样单元,对标记样本进行类别均衡采样,以获得标记训练样本集合;
在多个训练迭代的每个训练迭代期间:
筛选单元,使用级联的筛选模型对所述标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合;
训练单元,使用所述干净标记训练样本集合来训练分类模型;
修正单元,使用所述分类模型对所述错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;
组合单元,将所述干净标记训练样本集合和修正后的所述错误标记训练样本集合组合为新的所述标记训练样本集合,并且返回到所述级联的筛选模型以进行筛选。
第三方面,本申请实施例提供了一种电子设备,所述电子设备包括:存储器,用于存储由所述电子设备的一个或多个处理器执行的指令;以及处理器,是所述电子设备的处理器之一,用于执行所述存储器中存储的指令以实现上述第一方面及其可能实现提供的任一种用于训练机器学习模型的方法。
第四方面,本申请实施例提供了一种可读介质,所述可读介质上存储有指令,所述指令在电子设备上执行时使所述电子设备执行上述第一方面及其可能实现提供的任一种用于训练机器学习模型的方法。
第五方面,本申请实施例提供了一种计算机程序产品,所述计算机程序产品包括计算机可执行指令,所述指令被处理器执行以实施上述第一方面及其可能实现提供的任一种用于训练机器学习模型的方法。
附图说明
图1根据本申请的一些实施例,示出了用于训练机器学习模型的方法的流程示意图;
图2根据本申请的一些实施例,示出了类别均衡采样的原理图;
图3根据本申请的一些实施例,示出了级联筛选的原理图;
图4根据本申请的一些实施例,示出了分类训练的原理图;
图5根据本申请的一些实施例,示出了用于训练机器学习模型的系统的结构示意图;
图6根据本申请的一些实施例,示出了电子设备的结构示意图。
具体实施方式
本申请的说明性实施例包括但不限于用于训练机器学习模型的方法、系统、设备、介质和程序。
为便于理解本申请的技术方案,首先介绍训练机器学习模型的应用场景。但是应该理解的是,下述内容仅是为了解释说明本申请的实施例,而不对本申请的实施例进行限制。
通常使用训练后的机器学习模型对图像数据帧(例如,用于对象检测、分类等)、音频数据帧(例如,用于转录、语音识别等)和/或文本(例如,用于自然语言分类等)等进行处理,以生成分类预测结果。
对机器学习模型中的深度神经网络进行高度精确的训练通常需要大量的标记训练样本。然而,获得高质量的标记训练样本(例如,经由人类注释)的过程经常是有挑战性和昂贵的。相比之下,例如通过互联网外包等方式而获得具有噪声(即,不精确的标签或伪标签)的标记训练样本通常要便宜得多。但是,因为许多深度神经网络具有高记忆能力,所以噪声标签可能变得突出并且导致过拟合。
本申请的各实施例所公开的技术方案中对数据的获取、存储、使用、处理等均符合国家法律法规的相关规定。
图1根据本申请的一些实施例,示出了用于训练机器学习模型的方法的流程示意图。如图1所示,该用于训练机器学习模型的方法包括如下步骤:
S101:对标记样本进行类别均衡采样,以获得标记训练样本集合。
可以理解,在一些实施例中,该标记样本(也可以称为标记数据)通过互联网外包等廉价的方式而获得。该标记样本包括非噪声标记样本(也可以称为干净标记样本)和噪声标记样本(也可以称为错误标记样本)。
该标记样本中的每个标记样本与给定标签相关联。对于非噪声标记样本,该给定标签为地面真实标签,而对于噪声标记样本,该给定标签为错误标签(也可以称为伪标签)。
该标记样本中的每个标记样本为图像数据帧、音频数据帧和/或文本。优选地,该标记样本中的每个标记样本为图像数据帧,并且相关联的给定标签为该图像数据帧的文本描述符。
可以理解,在一些实施例中,从标记样本中采样多个标记样本并且组合为标记训练样本集合,以分批地训练机器学习模型。
取决于标记样本的内容和实际需要,该标记样本可以被划分到不同的类别,这些类别可以从相关联的给定标签中提取。例如,对于包括人物的图像数据帧,该图像数据帧的文本描述符可以包括对该人物所穿着的服装的描述,因此可以根据该人物穿着的服装将该图像数据帧划分到不同的类别(诸如,雪纺裙、羽绒服和牛仔衫等)。
如果从标记样本中随机采样多个标记样本并且组合为标记训练样本集合,则该标记训练样本集合中的所有标记训练样本的类别分布是不均衡的。例如,如果标记样本中的雪纺裙、羽绒服和牛仔衫的数量比是8:10:1,则该标记训练样本集合中的所有标记训练样本中的雪纺裙、羽绒服和牛仔衫的数量比也应该符合上述比例。在这种情况下,如果将这些标记训练样本的交叉熵损失一起进行比较,则很容易将数量较少的牛仔衫类别总是错误当作噪声标记数据,从而对于牛仔衫类别缺乏学习,导致总体精度下降。
基于此,本申请的实施例对标记样本进行类别均衡采样,以获得标记训练样本集合,该标记训练样本集合中的所有标记训练样本中的类别比例保持均衡,解决了由于样本丢弃造成的训练类别不公平的问题。
图2根据本申请的一些实施例,示出了类别均衡采样的原理图。如图2所示,该类别均衡采样进一步包括:
基于类别对标记样本进行分类,以获得多个分类标记样本集合;
从每个分类标记样本集合中选择第一数量的分类标记样本,以获得多个分类标记样本子集合;
从多个分类标记样本子集合中选择第二数量的分类标记样本子集合并组合,以获得标记训练样本集合。
例如,取决于标记样本的内容和实际需要,确定第一至第五类别。可以理解的是,类别的数量可以为任何数量,在此不受限制。
然后,将所有标记样本根据这些类别进行统计,以获得第一至第五分类标记样本集合。基于从相关联的给定标签中提取的内容,每类分类标记样本集合中的所有分类标记样本应该属于同一类别。可以理解的是,由于噪声标记等原因,可能存在部分分类标记样本实际上不应该属于这一类别的情形,这会在后续的筛选过程中被归类到错误标记训练样本集合。
然后,分别从第一至第五分类标记样本集合中选择四个(即,第一数量)分类标记样本,以获得第一至第五分类标记样本子集合。如果某个分类标记样本集合中的所有分类标记样本的数量小于四个,则可以重复选择这些分类标记样本,直到选中四个为止。
然后,从第一至第五分类标记样本子集合中选择两个(即,第二数量)分类标记样本子集合(例如,第二分类标记样本子集合和第三分类标记样本子集合)并组合,以获得标记训练样本集合。在该标记训练样本集合中,各个类别的标记训练样本的数量保持一致,从而实现了类别均衡采样,解决了由于样本丢弃造成的训练类别不公平的问题。
优选地,从第一至第五分类标记样本子集合中选择一个(即,第二数量)分类标记样本子集合(例如,第二分类标记样本子集合)作为标记训练样本集合,从而确保该标记训练样本集合中的所有标记训练样本均属于同一类别,方便对同一类别进行训练。
可以理解的是,第一数量和第二数量可以为任何数量,在此不受限制。在分批地训练机器学习模型的情况下,第一数量和第二数量可以取决于批大小(batch size),从而获得一个迷你批(mini batch)的标记训练样本集合。
返回到图1,如图1所示,在多个训练迭代的每个训练迭代期间,该用于训练机器学习模型的方法还包括如下步骤:
S102:使用级联的筛选模型对标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合。
相比于联合训练,本申请的实施例使用级联的筛选模型对标记训练样本集合进行逐层筛选,可以选择多个网络传递训练,结合集成学习的思想进一步提升样本筛选能力,从而更好地剔除噪声标记样本。
图3根据本申请的一些实施例,示出了级联筛选的原理图。如图3所示,级联的筛选模型包括第一深度神经网络、第二深度神经网络和第三深度神经网络,并且该级联筛选进一步包括:
分别将标记训练样本集合输入第一深度神经网络和第二深度神经网络;
使用第一深度神经网络输出标记训练样本集合中的第一部分标记训练样本到第二深度神经网络,以及使用第二深度神经网络输出标记训练样本集合中的第二部分标记训练样本到第一深度神经网络,用于分别更新第一深度神经网络和第二深度神经网络的反向传播参数;
使用第二深度神经网络输出标记训练样本集合中的第二部分标记训练样本到第三深度神经网络,用于训练第三深度神经网络。
使用第一深度神经网络和第二深度神经网络进行联合训练,将彼此认为干净的样本输出到同伴网络进行反向传播参数更新,并且其中一个网络还将其认为干净的样本输出到下一个网络进行训练,利用课程学习的思想,逐步筛选出干净的样本。
可以理解,在一些实施例中,对于标记训练样本集合中的每个分类标记样本子集合,使用第一深度神经网络计算分类标记样本子集合中的每个分类标记样本的logits函数值,并且将logits函数值大于第一阈值的多个分类标记样本组合为第一部分标记训练样本,以及使用第二深度神经网络计算分类标记样本子集合中的每个分类标记样本的logits函数值,并且将logits函数值大于第二阈值的多个分类标记样本组合为第二部分标记训练样本。换句话说,第一深度神经网络和第二深度神经网络将一个迷你批的标记训练样本集合中的每一类别中的logits函数值较大的样本输出到同伴网络。可以理解的是,第一阈值和第二阈值可以相同,也可以不同,并且可以为任何值,在此不受限制。
可以理解,在一些实施例中,第一深度神经网络、第二深度神经网络和第三深度神经网络均使用平均绝对误差作为损失函数。
平均绝对误差损失曲线呈V字型,即,在各个损失值的梯度是一样的,使得模型不会对于损失较大的样本产生偏置,从而使得筛选网络对于噪声标记样本的鲁棒性更强,解决了神经网络在噪声标记样本学习中容易过拟合的问题。
返回到图1,如图1所示,在多个训练迭代的每个训练迭代期间,该用于训练机器学习模型的方法还包括如下步骤:
S103:使用干净标记训练样本集合来训练分类模型;
S104:使用分类模型对错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;
S105:将干净标记训练样本集合和修正后的错误标记训练样本集合组合为新的标记训练样本集合,并且返回到级联的筛选模型以进行筛选。
本申请的实施例利用上一阶段筛选出的干净样本进行分类网络训练,对于噪声样本进行伪标签修正,进一步将所有样本送入上一阶段再进行样本的筛选以及样本修正和分类训练,如此反复,最终得到一个鲁棒性的分类模型。
图4根据本申请的一些实施例,示出了分类训练的原理图。如图4所示,分类模型包括第四深度神经网络,并且第四深度神经网络使用交叉熵作为损失函数。
可以理解,在一些实施例中,第一深度神经网络、第二深度神经网络和第三深度神经网络均采用小网络结构并且第四深度神经网络采用大网络结构,从而实现配对。参考图3和图4,第一深度神经网络、第二深度神经网络和第三深度神经网络均采用ResNet-18网络结构并且第四深度神经网络采用ResNet-50网络结构。可选地,第一深度神经网络、第二深度神经网络和第三深度神经网络均采用ResNet-50网络结构并且第四深度神经网络采用ResNet-101网络结构。
可以理解,在一些实施例中,使用干净标记训练样本集合来训练多个分类模型,并且该用于训练机器学习模型的方法进一步包括:
分别使用训练后的多个分类模型进行分类预测;
对训练后的多个分类模型的分类预测结果进行平均,作为最终的分类预测结果。
本申请的实施例在测试过程中,使用集成学习的策略,设计两个大网络分类模型进行分类预测,取两个大网络的平均值作为最终的预测结果。
图5根据本申请的一些实施例,示出了用于训练机器学习模型的系统的结构示意图。如图5所示,该用于训练机器学习模型的系统包括如下单元:
采样单元501,对标记样本进行类别均衡采样,以获得标记训练样本集合;
在多个训练迭代的每个训练迭代期间:
筛选单元502,使用级联的筛选模型对标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合;
训练单元503,使用干净标记训练样本集合来训练分类模型;
修正单元504,使用分类模型对错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;
组合单元505,将干净标记训练样本集合和修正后的错误标记训练样本集合组合为新的标记训练样本集合,并且返回到级联的筛选模型以进行筛选。
第一实施方式是与本实施方式相对应的方法实施方式,本实施方式可与第一实施方式互相配合实施。第一实施方式中提到的相关技术细节在本实施方式中依然有效,为了减少重复,这里不再赘述。相应地,本实施方式中提到的相关技术细节也可应用在第一实施方式中。
图6根据本申请的一些实施例,示出了一种可以执行前述各实施例提供的用于训练机器学习模型的方法的电子设备600的结构示意图。如图6所示,电子设备600可以包括一个或多个处理器601、系统内存602、非易失性存储器(Non-Volatile Memory,NVM)603、通信接口604、输入/输出(I/O)设备605、以及用于耦接处理器601、系统内存602、非易失性存储器603、通信接口604和输入/输出(I/O)设备605的系统控制逻辑606。其中:
处理器601可以包括一个或多个单核或多核处理器。在一些实施例中,处理器601可以包括通用处理器和专用处理器(例如,图形处理器、应用处理器、基带处理器等)的任意组合。在一些实施例中,处理器601可以执行前述各实施例提供的用于训练机器学习模型的方法的指令,例如,类别均衡采样、级联筛选、分类训练、伪标签修正、新样本组合的指令等。
系统内存602是易失性存储器,例如,随机存取存储器(Random-Access Memory,RAM)、双倍数据率同步动态随机存取存储器(Double Data Rate Synchronous DynamicRandom Access Memory,DDR SDRAM)等。系统内存602用于临时存储数据和/或指令,例如,在一些实施例中,系统内存602可以用于临时存储各种训练样本等。
非易失性存储器603可以包括用于存储数据和/或指令的一个或多个有形的、非暂时性的计算机可读介质。在一些实施例中,非易失性存储器603可以包括闪存等任意合适的非易失性存储器和/或任意合适的非易失性存储设备,例如,硬盘驱动器(Hard DiskDrive,HDD)、光盘(Compact Disc,CD)、数字通用光盘(Digital Versatile Disc,DVD)、固态硬盘(Solid-State Drive,SSD)等。非易失性存储器603也可以是可移动存储介质,例如,安全数字(Secure Digital,SD)存储卡等。在一些实施例中,非易失性存储器603可以用于存储前述各实施例提供的用于训练机器学习模型的方法的指令,也可以用于存储最终的分类预测结果等。
特别地,系统内存602和非易失性存储器603可以分别包括:指令607的临时副本和永久副本。指令607可以包括:由处理器601中的至少一个执行时使电子设备600实现本申请各实施例提供的用于训练机器学习模型的方法。
网络接口604可以包括收发器,用于为电子设备600提供有线或无线通信接口,进而通过一个或多个网络与任意其他合适的设备进行通信。在一些实施例中,网络接口604可以集成于电子设备600的其他组件,例如,网络接口604可以集成于处理器601中。在一些实施例中,电子设备600可以通过网络接口604和其他设备通信,例如,不同设备之间可以通过各自的网络接口604耦接,从而实现不同设备中指令、数据的传递。
输入/输出(I/O)设备605可以包括用户界面,使得用户能够与电子设备600进行交互。例如,在一些实施例中,输入/输出(I/O)设备605可以包括显示器等输出设备,用于显示标记样本和最终的分类预测结果的用户界面,还可以包括键盘、鼠标、触摸屏等输入设备。用户可以通过用户界面以及键盘、鼠标、触摸屏等输入设备与标记样本和最终的分类预测结果进行交互,以便于配置和调整训练内容和训练模型等。
系统控制逻辑606可以包括任意合适的接口控制器,以为电子设备600的其他模块提供任意合适的接口。例如,在一些实施例中,系统控制逻辑606可以包括一个或多个存储器控制器,以提供连接到系统内存602和非易失性存储器603的接口。
在一些实施例中,处理器601中的至少一个可以与用于系统控制逻辑606的一个或多个控制器的逻辑封装在一起,以形成系统封装(System in Package,SiP)。在另一些实施例中,处理器601中的至少一个还可以与用于系统控制逻辑606的一个或多个控制器的逻辑集成在同一芯片上,以形成片上系统(System-on-Chip,SoC)。
可以理解,电子设备600可以是能够实现机器学习模型训练相关功能的任意电子设备,包括但不限于计算机、服务器、平板电脑、手持计算机等,本申请实施例不做限定。
可以理解,本申请实施例示出的电子设备600的结构并不构成对电子设备600的具体限定。在本申请另一些实施例中,电子设备600可以包括比图示更多或更少的部件,或者组合某些部件,或者拆分某些部件,或者不同的部件布置。图示的部件可以以硬件、软件或软件和硬件的组合实现。
本申请公开的机制的各实施例可以被实现在硬件、软件、固件或这些实现方法的组合中。本申请的实施例可实现为在可编程系统上执行的计算机程序或程序代码,该可编程系统包括至少一个处理器、存储系统(包括易失性和非易失性存储器和/或存储元件)、至少一个输入设备以及至少一个输出设备。
可将程序代码应用于输入指令,以执行本申请描述的各功能并生成输出信息。可以按已知方式将输出信息应用于一个或多个输出设备。为了本申请的目的,处理系统包括具有诸如例如数字信号处理器(DSP)、微控制器、专用集成电路(ASIC)或微处理器之类的处理器的任何系统。
程序代码可以用高级程序化语言或面向对象的编程语言来实现,以便与处理系统通信。在需要时,也可用汇编语言或机器语言来实现程序代码。事实上,本申请中描述的机制不限于任何特定编程语言的范围。在任一情形下,该语言可以是编译语言或解释语言。
在一些情况下,所公开的实施例可以以硬件、固件、软件或其任何组合来实现。所公开的实施例还可以被实现为由一个或多个暂时或非暂时性机器可读(例如,计算机可读)存储介质承载或存储在其上的指令,其可以由一个或多个处理器读取和执行。例如,指令可以通过网络或通过其他计算机可读介质分发。因此,机器可读介质可以包括用于以机器(例如,计算机)可读的形式存储或传输信息的任何机制,包括但不限于,软盘、光盘、光碟、只读存储器(CD-ROMs)、磁光盘、只读存储器(ROM)、随机存取存储器(RAM)、可擦除可编程只读存储器(EPROM)、电可擦除可编程只读存储器(EEPROM)、磁卡或光卡、闪存、或用于利用因特网以电、光、声或其他形式的传播信号来传输信息(例如,载波、红外信号数字信号等)的有形的机器可读存储器。因此,机器可读介质包括适合于以机器(例如,计算机)可读的形式存储或传输电子指令或信息的任何类型的机器可读介质。
在附图中,可以以特定布置和/或顺序示出一些结构或方法特征。然而,应该理解,可能不需要这样的特定布置和/或排序。而是,在一些实施例中,这些特征可以以不同于说明性附图中所示的方式和/或顺序来布置。另外,在特定图中包括结构或方法特征并不意味着暗示在所有实施例中都需要这样的特征,并且在一些实施例中,可以不包括这些特征或者可以与其他特征组合。
需要说明的是,本申请各设备实施例中提到的各单元/模块都是逻辑单元/模块,在物理上,一个逻辑单元/模块可以是一个物理单元/模块,也可以是一个物理单元/模块的一部分,还可以以多个物理单元/模块的组合实现,这些逻辑单元/模块本身的物理实现方式并不是最重要的,这些逻辑单元/模块所实现的功能的组合才是解决本申请所提出的技术问题的关键。此外,为了突出本申请的创新部分,本申请上述各设备实施例并没有将与解决本申请所提出的技术问题关系不太密切的单元/模块引入,这并不表明上述设备实施例并不存在其它的单元/模块。
需要说明的是,在本专利的示例和说明书中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
虽然通过参照本申请的某些优选实施例,已经对本申请进行了图示和描述,但本领域的普通技术人员应该明白,可以在形式上和细节上对其作各种改变,而不偏离本申请的精神和范围。
Claims (12)
1.一种用于训练机器学习模型的方法,其特征在于,所述方法用于电子设备,并且所述方法包括:
对标记样本进行类别均衡采样,以获得标记训练样本集合;
在多个训练迭代的每个训练迭代期间:
使用级联的筛选模型对所述标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合;
使用所述干净标记训练样本集合来训练分类模型;
使用所述分类模型对所述错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;
将所述干净标记训练样本集合和修正后的所述错误标记训练样本集合组合为新的所述标记训练样本集合,并且返回到所述级联的筛选模型以进行筛选。
2.根据权利要求1所述的方法,其特征在于,对标记样本进行类别均衡采样,以获得标记训练样本集合进一步包括:
基于类别对所述标记样本进行分类,以获得多个分类标记样本集合;
从每个分类标记样本集合中选择第一数量的分类标记样本,以获得多个分类标记样本子集合;
从所述多个分类标记样本子集合中选择第二数量的分类标记样本子集合并组合,以获得所述标记训练样本集合。
3.根据权利要求2所述的方法,其特征在于,所述级联的筛选模型包括第一深度神经网络、第二深度神经网络和第三深度神经网络,并且其中,使用级联的筛选模型对所述标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合进一步包括:
分别将所述标记训练样本集合输入所述第一深度神经网络和所述第二深度神经网络;
使用所述第一深度神经网络输出所述标记训练样本集合中的第一部分标记训练样本到所述第二深度神经网络,以及使用所述第二深度神经网络输出所述标记训练样本集合中的第二部分标记训练样本到所述第一深度神经网络,用于分别更新所述第一深度神经网络和所述第二深度神经网络的反向传播参数;
使用所述第二深度神经网络输出所述标记训练样本集合中的所述第二部分标记训练样本到所述第三深度神经网络,用于训练所述第三深度神经网络。
4.根据权利要求3所述的方法,其特征在于,对于所述标记训练样本集合中的每个分类标记样本子集合,使用所述第一深度神经网络计算所述分类标记样本子集合中的每个分类标记样本的logits函数值,并且将logits函数值大于第一阈值的多个分类标记样本组合为所述第一部分标记训练样本,以及使用所述第二深度神经网络计算所述分类标记样本子集合中的每个分类标记样本的logits函数值,并且将logits函数值大于第二阈值的多个分类标记样本组合为所述第二部分标记训练样本。
5.根据权利要求3或4所述的方法,其特征在于,所述第一深度神经网络、所述第二深度神经网络和所述第三深度神经网络均使用平均绝对误差作为损失函数。
6.根据权利要求5所述的方法,其特征在于,所述分类模型包括第四深度神经网络,并且所述第四深度神经网络使用交叉熵作为损失函数。
7.根据权利要求6所述的方法,其特征在于,所述第一深度神经网络、所述第二深度神经网络和所述第三深度神经网络均采用ResNet-18网络结构并且所述第四深度神经网络采用ResNet-50网络结构,或者所述第一深度神经网络、所述第二深度神经网络和所述第三深度神经网络均采用ResNet-50网络结构并且所述第四深度神经网络采用ResNet-101网络结构。
8.根据权利要求1所述的方法,其特征在于,使用所述干净标记训练样本集合来训练多个分类模型,并且其中,所述方法进一步包括:
分别使用训练后的所述多个分类模型进行分类预测;
对训练后的所述多个分类模型的分类预测结果进行平均,作为最终的分类预测结果。
9.一种用于训练机器学习模型的系统,其特征在于,所述系统包括:
采样单元,对标记样本进行类别均衡采样,以获得标记训练样本集合;
在多个训练迭代的每个训练迭代期间:
筛选单元,使用级联的筛选模型对所述标记训练样本集合进行筛选,以获得干净标记训练样本集合和错误标记训练样本集合;
训练单元,使用所述干净标记训练样本集合来训练分类模型;
修正单元,使用所述分类模型对所述错误标记训练样本集合中的每个错误标记训练样本的伪标签进行修正;
组合单元,将所述干净标记训练样本集合和修正后的所述错误标记训练样本集合组合为新的所述标记训练样本集合,并且返回到所述级联的筛选模型以进行筛选。
10.一种电子设备,其特征在于,所述电子设备包括:
存储器,用于存储由所述电子设备的一个或多个处理器执行的指令;以及
处理器,是所述电子设备的处理器之一,用于执行所述存储器中存储的指令以实现权利要求1至8中任一项所述的用于训练机器学习模型的方法。
11.一种可读介质,其特征在于,所述可读介质上存储有指令,所述指令在电子设备上执行时使所述电子设备执行权利要求1至8中任一项所述的用于训练机器学习模型的方法。
12.一种计算机程序产品,其特征在于,所述计算机程序产品包括计算机可执行指令,所述指令被处理器执行以实施权利要求1至8中任一项所述的用于训练机器学习模型的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211151281.1A CN117787430A (zh) | 2022-09-21 | 2022-09-21 | 用于训练机器学习模型的方法、系统、设备、介质和程序 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211151281.1A CN117787430A (zh) | 2022-09-21 | 2022-09-21 | 用于训练机器学习模型的方法、系统、设备、介质和程序 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117787430A true CN117787430A (zh) | 2024-03-29 |
Family
ID=90385528
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211151281.1A Pending CN117787430A (zh) | 2022-09-21 | 2022-09-21 | 用于训练机器学习模型的方法、系统、设备、介质和程序 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117787430A (zh) |
-
2022
- 2022-09-21 CN CN202211151281.1A patent/CN117787430A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109471938B (zh) | 一种文本分类方法及终端 | |
US11886799B2 (en) | Determining functional and descriptive elements of application images for intelligent screen automation | |
CN106778241B (zh) | 恶意文件的识别方法及装置 | |
US11810380B2 (en) | Methods and apparatus to decode documents based on images using artificial intelligence | |
CN109829306A (zh) | 一种优化特征提取的恶意软件分类方法 | |
JP2019519042A (ja) | 情報をプッシュする方法及びデバイス | |
JP7382350B2 (ja) | 効率的なラベル伝搬のためのアンサンブルベースのデータキュレーションパイプライン | |
CN111753290A (zh) | 软件类型的检测方法及相关设备 | |
CN113052577A (zh) | 一种区块链数字货币虚拟地址的类别推测方法及系统 | |
CN114978624A (zh) | 钓鱼网页检测方法、装置、设备及存储介质 | |
CN111582315A (zh) | 样本数据处理方法、装置及电子设备 | |
CN111062490A (zh) | 一种包含隐私数据的网络数据的处理方法及装置 | |
CN117787430A (zh) | 用于训练机器学习模型的方法、系统、设备、介质和程序 | |
CN110879832A (zh) | 目标文本检测方法、模型训练方法、装置及设备 | |
CN114467144A (zh) | 减少测序平台特异性错误的体细胞突变检测装置及方法 | |
CN113298185B (zh) | 模型训练方法、异常文件检测方法、装置、设备及介质 | |
US20220391630A1 (en) | Methods, systems, articles of manufacture, and apparatus to extract shape features based on a structural angle template | |
CN112308141B (zh) | 一种扫描票据分类方法、系统及可读存储介质 | |
CN111291726B (zh) | 医疗票据分拣方法、装置、设备和介质 | |
CN114443878A (zh) | 图像分类方法、装置、设备及存储介质 | |
CN113853243A (zh) | 游戏道具分类及神经网络的训练方法和装置 | |
Monarev et al. | Prior classification of stego containers as a new approach for enhancing steganalyzers accuracy | |
CN110866543B (zh) | 图片检测及图片分类模型的训练方法和装置 | |
US11797893B2 (en) | Machine learning for generating an integrated format data record | |
CN106446902A (zh) | 非文字图像识别方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |