CN112633515A - 基于样本剔除的模型训练方法及设备 - Google Patents
基于样本剔除的模型训练方法及设备 Download PDFInfo
- Publication number
- CN112633515A CN112633515A CN202011493895.9A CN202011493895A CN112633515A CN 112633515 A CN112633515 A CN 112633515A CN 202011493895 A CN202011493895 A CN 202011493895A CN 112633515 A CN112633515 A CN 112633515A
- Authority
- CN
- China
- Prior art keywords
- training set
- loss
- samples
- training
- sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 427
- 238000000034 method Methods 0.000 title claims abstract description 46
- 230000008030 elimination Effects 0.000 title claims abstract description 14
- 238000003379 elimination reaction Methods 0.000 title claims abstract description 14
- 238000007619 statistical method Methods 0.000 claims description 54
- 238000012216 screening Methods 0.000 claims description 27
- 230000008859 change Effects 0.000 claims description 22
- 230000015654 memory Effects 0.000 claims description 18
- 238000002474 experimental method Methods 0.000 abstract description 5
- 238000012360 testing method Methods 0.000 abstract description 3
- 230000000875 corresponding effect Effects 0.000 description 43
- 230000006870 function Effects 0.000 description 6
- 230000005291 magnetic effect Effects 0.000 description 5
- 238000004590 computer program Methods 0.000 description 4
- 238000004458 analytical method Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 238000013136 deep learning model Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 230000002159 abnormal effect Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000002596 correlated effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000005192 partition Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 230000001052 transient effect Effects 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
- 230000003936 working memory Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Medical Informatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Complex Calculations (AREA)
Abstract
本发明的目的是提供一种基于样本剔除的模型训练方法及设备,本发明通过在原始训练集中剔除容易样本,原始训练集中的容易样本剔除后,原始训练集中剩余的样本为困难样本,将困难样本作为新的训练集,使用所述新的训练集对所述模型进行的新的一轮迭代训练,直至模型收敛。本发明通过剔除容易样本,在模型训练后期增大了困难样本在训练集中所占比例,是对训练集中的样本分布的动态调整,从而加强模型对困难样本的学习。实验表明,该方法能够提升模型在训练集与测试集上的性能。
Description
技术领域
本发明涉及计算机领域,尤其涉及一种基于样本剔除的模型训练方法及设备。
背景技术
在深度学习领域中,训练集所使用的样本分布直接影响了模型训练的过程,从而决定了深度学习模型的好坏。使用一个分布均衡的训练集进行训练,通常可以使模型具有更好的预测性能与泛化能力。
传统的深度学习模型训练方法通常使用固定的训练集进行训练,这样的训练方式往往会使模型更加稳定。但是当训练集中出现样本分布不均衡等问题时,模型在各个类型样本上的表现往往不同。存在某些模式的样本更容易被模型学习,而另一些模式的样本不易学习。因此模型在不同类型的样本上的收敛速度是不同的。当训练集样本固定时,各类型样本在训练集中所占比例固定不变。当模型在易学习样本上已经收敛时,继续使用易学样本对模型进行训练会干扰模型对困难样本的学习,从而导致模型整体性能的提升到达瓶颈。如果能够动态改变训练集的样本分布,在模型训练的不同时期加强模型对不同类型样本的学习,则可以大大优化模型的性能。
视频预测类任务不同于一般的判别类任务,无法根据简单的准确率等方式确定容易样本与困难样本。因此若想在视频预测类任务上进行样本遗忘实验,必须先制定判定难易样本的标准。在模型的训练过程中,损失函数直接决定了模型学习的方向。因此需要先对不同样本对应损失函数的数值进行深入的分析,确定划分难易样本的阈值,再进行样本剔除实验。
发明内容
本发明的一个目的是提供一种基于样本剔除的模型训练方法及设备。
根据本发明的一个方面,提供了一种基于样本剔除的模型训练方法及设备,该方法包括:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
进一步的,上述方法中,对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果,包括:
对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围,基于模型训练目标和所有样本的损失值的取值范围,初步筛选一个或多个类型的损失;
基于所有样本的损失值的取值范围,分析在当前轮的迭代训练的各次迭代训练阶段中损失值的分布的变化情况,基于所述变化情况确定原始训练集中所有样本的整体损失值收敛的训练迭代次数的次数阈值N,N大于等于1;
若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,基于所述相关性从所述初步筛选损失的类型中进一步筛选出最终的一个或多个类型的损失。
进一步的,上述方法中,对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围,基于模型训练目标和所有样本的损失值的取值范围,初步筛选一个或多个类型的损失,包括:
对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围;
基于所有样本的损失值的取值范围获取高损失值的样本和低损失值的样本;
基于获取到的高损失值的样本和低损失值的样本,初步筛选一个或多个类型的损失。
进一步的,上述方法中,若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,基于所述相关性从所述初步筛选损失的类型中进一步筛选出最终的一个或多个类型的损失,包括:
若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,得到相关性小于预设相关性阈值的一个或多个类型的损失;
将相关性小于预设相关性阈值的一个或多个类型的损失作为筛选出的最终的一个或多个类型的损失。
进一步的,上述方法中,基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围,包括:
确定进一步筛选出的最终的一个或多个类型的损失的对应的损失值的容易样本阈值范围。
进一步的,上述方法中,确定进一步筛选出的最终的一个或多个类型的损失的对应的损失值的容易样本阈值范围,包括:
对筛选出的最终的每一类型的损失,分析最终的每一类型的损失的损失值的取值范围在当前轮的迭代训练的各次迭代训练阶段的变化,当某次次迭代训练阶段的损失值的取值范围相对变化小于预设变化阈值时,确认模型此时已经收敛;
选取模型收敛时相对变化小于预设变化阈值的损失值的取值范围的预设百分位数作为容易样本阈值范围。
进一步的,上述方法中,使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,基于所述判断结果,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集,包括:
步骤S31,使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练的第一次迭代训练阶段,判断所述原始训练集中的每个样本的进一步筛选出最终的一个或多个类型的损失的损失值是否在对应的容易样本阈值范围之内;
步骤S32,若某个样本的某个进一步筛选出最终的一个或多个类型的损失的损失值在对应的容易样本阈值范围之内,则将该样本确定为容易样本;
步骤S33,对所述容易样本的当前迭代训练阶段的次数计数,以得到容易样本的计数值,判断容易样本的计数值是否超过训练迭代次数的次数阈值N;
步骤S34,若容易样本的计数值未超过训练迭代次数的次数阈值N,则使用原始训练集对所述模型进行的当前轮迭代训练的下一次的迭代训练后,对该容易样本重复从步骤S33开始执行;
步骤S35,若容易样本的计数值超过训练迭代次数的次数阈值N,将该容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集。
根据本发明的另一方面,还提供一种基于样本剔除的模型训练设备,其中,该设备包括:
统计组装置,用于使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
范围确定装置,用于基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
剔除装置,用于使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
训练装置,用于使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
根据本发明的另一方面,还提供一种基于计算的设备,其中,包括:
处理器;以及
被安排成存储计算机可执行指令的存储器,所述可执行指令在被执行时使所述处理器:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
根据本发明的另一方面,还提供一种计算机可读存储介质,其上存储有计算机可执行指令,其中,该计算机可执行指令被处理器执行时使得该处理器:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
与现有技术相比,本发明通过在原始训练集中剔除容易样本,原始训练集中的容易样本剔除后,原始训练集中剩余的样本为困难样本,将困难样本作为新的训练集,使用所述新的训练集对所述模型进行的新的一轮迭代训练,直至模型收敛。本发明通过剔除容易样本,在模型训练后期增大了困难样本在训练集中所占比例,是对训练集中的样本分布的动态调整,从而加强模型对困难样本的学习。实验表明,该方法能够提升模型在训练集与测试集上的性能。
附图说明
通过阅读参照以下附图所作的对非限制性实施例所作的详细描述,本发明的其它特征、目的和优点将会变得更明显:
图1示出本发明一实施例的基于样本剔除的模型训练方法的损失函数分析方法流程图;
图2示出本发明一实施例的基于样本剔除的模型训练方法的训练迭代过程中取出样本的流程图;
图3示出本发明一实施例的基于样本剔除的模型训练方法的流程图。
附图中相同或相似的附图标记代表相同或相似的部件。
具体实施方式
下面结合附图对本发明作进一步详细描述。
在本申请一个典型的配置中,终端、服务网络的设备和可信方均包括一个或多个处理器(CPU)、输入/输出接口、网络接口和内存。
内存可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM)。内存是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括非暂存电脑可读媒体(transitory media),如调制的数据信号和载波。
如图1、2和3所示,本发明提供一种基于样本剔除的模型训练方法,所述方法包括:
步骤S1,使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值(loss),并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
在此,所述原始训练集中的样本可以是视频样本;
步骤S2,基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
在此,可以根据对模型训练不同阶段loss的统计学分析结果,确定划分容易样本的一个或多个类型的损失对应的损失值(loss)的阈值范围;
步骤S3,使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
步骤S4,使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
在此,本发明通过在原始训练集中剔除容易样本,原始训练集中的容易样本剔除后,原始训练集中剩余的样本为困难样本,将困难样本作为新的训练集,使用所述新的训练集对所述模型进行的新的一轮迭代训练,直至模型收敛。本发明通过剔除容易样本,在模型训练后期增大了困难样本在训练集中所占比例,是对训练集中的样本分布的动态调整,从而加强模型对困难样本的学习。实验表明,该方法能够提升模型在训练集与测试集上的性能。
本发明的基于样本剔除的模型训练方法一实施例中,步骤S1中,对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果,包括:
步骤S11中,对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围,基于模型训练目标和所有样本的损失值的取值范围,初步筛选一个或多个类型的损失;
例如,可以初步筛选到一个a类型的损失;或者可以初步筛选到三个类型的损失,如a类型、b类型和c类型的损失;
步骤S12中,基于所有样本的损失值的取值范围,分析在当前轮的迭代训练的各次迭代训练阶段中损失值的分布的变化情况,基于所述变化情况确定原始训练集中所有样本的整体损失值收敛的训练迭代次数(epoch)的次数阈值N,N大于等于1;
步骤S13中,若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,基于所述相关性从所述初步筛选损失的类型中进一步筛选出最终的一个或多个类型的损失;
例如,可以初步筛选到三个类型的损失,如a类型、b类型和c类型的损失,那么可以分析a、b、c类型的损失之间的相关性,可以所述初步筛选损失的类型中进一步筛选出最终的类型的损失为b和c类型的损失。
本发明的基于样本剔除的模型训练方法一实施例中,步骤S11中,对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围,基于模型训练目标和所有样本的损失值的取值范围,初步筛选一个或多个类型的损失,包括:
步骤S111,对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围;
步骤S112,基于所有样本的损失值的取值范围获取高损失值的样本和低损失值的样本;
步骤S113,基于获取到的高损失值的样本和低损失值的样本,初步筛选一个或多个类型的损失。
在此,可以通过对异常点包括高损失值和低损失值的样本的具体案例分析,确定损失值的具体取值范围,基于损失值的具体取值范围确定对模型训练目标方向的影响的一个或多个类型的损失,
具体的,可以包括:根据损失值的整体分布,确定损失值的在训练集所有样本上的损失值的取值范围。可以结合损失值的分布直方图与箱式图,选取损失值的分布中的异常点样本包括高损失值的样本(例如可以是所有样本中损失值最高的5%的样本)以及低损失值的样本(例如可以是所有样本中损失值最低的5%的样本)进行样例分析。具体可以通过直方图中得到所有样本上的损失值的取值范围,通过箱式图得到损失值的均值、方差和百分位点等等。
可以结合模型预测结果的视觉效果,确定高损失值和低损失值对应的样本的特征,从而可以分析得出低损失值和高损失值对模型训练训练方向的影响。例如,在临近预报任务中,通过对临近预报模型中所使用的感知损失(vgg loss)进行分析,发现类型为感知损失的感知损失值较高的样本对应的预测效果图中,高色阶像素区域预测效果较差。因此可以发现感知损失这一类型可以优化模型对高色阶像素区域预测(训练目标方向)的准确度。所以在此可以初步筛选到一个类型的损失为感知损失(vgg loss)。
本发明的基于样本剔除的模型训练方法一实施例中,步骤S13中,若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,基于所述相关性从所述初步筛选损失的类型中进一步筛选出最终的一个或多个类型的损失,包括:
步骤S131,若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,得到相关性小于预设相关性阈值(相关性较低)的一个或多个类型的损失;
步骤S132,将相关性较低的一个或多个类型的损失作为筛选出的最终的一个或多个类型的损失。
在此,可以分析不同类型的损失间的相关性,确定用于划分难易样本的一个或多个类型的损失,具体可以包括:
通过绘制不同类型的损失间的散点图,计算不同类型的损失的皮尔森相关系数等方式分析不同loss之间的相关性。高度相关的类型的损失之间所携带的信息高度重合,因此可以选择其中的一个类型的损失;而相关性较低的类型的损失之间所携带的信息差距较大,可能对模型的优化方向贡献不同,因此应当全部选取。根据不同类型的损失之间的相关性可以确定用于划分难易样本的损失的类型。
本发明的基于样本剔除的模型训练方法一实施例中,步骤S2,基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围,包括:
步骤S21,确定进一步筛选出的最终的一个或多个类型的损失的对应的损失值的容易样本阈值范围。
本发明的基于样本剔除的模型训练方法一实施例中,步骤S21,确定进一步筛选出的最终的一个或多个类型的损失的对应的损失值的容易样本阈值范围,包括:
步骤S211,对S13中筛选出的最终的每一类型的损失,分析最终的每一类型的损失的损失值的取值范围在当前轮的迭代训练的各次迭代训练阶段的变化,当某次次迭代训练阶段的损失值的取值范围相对变化小于预设变化阈值如5%时,确认模型此时已经收敛;
步骤S212,选取模型收敛时相对变化小于预设变化阈值的损失值的取值范围的预设百分位数(如90%)作为容易样本阈值范围。
在此,当某个样本的损失值小于该容易样本阈值范围时,认为该样本的该损失值收敛,否则该样本的该损失值未收敛。
本发明的基于样本剔除的模型训练方法一实施例中,步骤S3,步骤S3,使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,基于所述判断结果,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集,包括:
步骤S31,使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练的第一次迭代训练阶段,判断所述原始训练集中的每个样本的进一步筛选出最终的一个或多个类型的损失的损失值是否在对应的容易样本阈值范围之内;
步骤S32,若某个样本的某个进一步筛选出最终的一个或多个类型的损失的损失值在对应的容易样本阈值范围之内,则将该样本确定为容易样本;
步骤S33,对所述容易样本的当前迭代训练阶段的次数计数,以得到容易样本的计数值,判断容易样本的计数值是否超过训练迭代次数(epoch)的次数阈值N;
步骤S34,若容易样本的计数值未超过训练迭代次数(epoch)的次数阈值N,则使用原始训练集对所述模型进行的当前轮迭代训练的下一次的迭代训练后,对该容易样本重复从步骤S33开始执行;
步骤S35,若容易样本的计数值超过训练迭代次数(epoch)的次数阈值N,即该容易样本在连续N次训练迭代中均为容易样本,将该容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
步骤S4,使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛,包括:
使用新的训练集对所述模型进行的当前轮迭代训练的后续次的迭代训练,直至模型收敛。
根据本发明的另一方面,还提供一种基于样本剔除的模型训练设备,其中,该设备包括:
统计组装置,用于使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
范围确定装置,用于基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
剔除装置,用于使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
训练装置,用于使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
根据本发明的另一方面,还提供一种基于计算的设备,其中,包括:
处理器;以及
被安排成存储计算机可执行指令的存储器,所述可执行指令在被执行时使所述处理器:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
根据本发明的另一方面,还提供一种计算机可读存储介质,其上存储有计算机可执行指令,其中,该计算机可执行指令被处理器执行时使得该处理器:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
本发明的各设备和存储介质实施例的详细内容,具体可参见各方法实施例的对应部分,在此,不再赘述。
显然,本领域的技术人员可以对本申请进行各种改动和变型而不脱离本申请的精神和范围。这样,倘若本申请的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包含这些改动和变型在内。
需要注意的是,本发明可在软件和/或软件与硬件的组合体中被实施,例如,可采用专用集成电路(ASIC)、通用目的计算机或任何其他类似硬件设备来实现。在一个实施例中,本发明的软件程序可以通过处理器执行以实现上文所述步骤或功能。同样地,本发明的软件程序(包括相关的数据结构)可以被存储到计算机可读记录介质中,例如,RAM存储器,磁或光驱动器或软磁盘及类似设备。另外,本发明的一些步骤或功能可采用硬件来实现,例如,作为与处理器配合从而执行各个步骤或功能的电路。
另外,本发明的一部分可被应用为计算机程序产品,例如计算机程序指令,当其被计算机执行时,通过该计算机的操作,可以调用或提供根据本发明的方法和/或技术方案。而调用本发明的方法的程序指令,可能被存储在固定的或可移动的记录介质中,和/或通过广播或其他信号承载媒体中的数据流而被传输,和/或被存储在根据所述程序指令运行的计算机设备的工作存储器中。在此,根据本发明的一个实施例包括一个装置,该装置包括用于存储计算机程序指令的存储器和用于执行程序指令的处理器,其中,当该计算机程序指令被该处理器执行时,触发该装置运行基于前述根据本发明的多个实施例的方法和/或技术方案。
对于本领域技术人员而言,显然本发明不限于上述示范性实施例的细节,而且在不背离本发明的精神或基本特征的情况下,能够以其他的具体形式实现本发明。因此,无论从哪一点来看,均应将实施例看作是示范性的,而且是非限制性的,本发明的范围由所附权利要求而不是上述说明限定,因此旨在将落在权利要求的等同要件的含义和范围内的所有变化涵括在本发明内。不应将权利要求中的任何附图标记视为限制所涉及的权利要求。此外,显然“包括”一词不排除其他单元或步骤,单数不排除复数。装置权利要求中陈述的多个单元或装置也可以由一个单元或装置通过软件或者硬件来实现。第一,第二等词语用来表示名称,而并不表示任何特定的顺序。
Claims (10)
1.一种基于样本剔除的模型训练方法,其中,该方法包括:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
2.根据权利要求1所述的方法,其中,对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果,包括:
对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围,基于模型训练目标和所有样本的损失值的取值范围,初步筛选一个或多个类型的损失;
基于所有样本的损失值的取值范围,分析在当前轮的迭代训练的各次迭代训练阶段中损失值的分布的变化情况,基于所述变化情况确定原始训练集中所有样本的整体损失值收敛的训练迭代次数的次数阈值N,N大于等于1;
若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,基于所述相关性从所述初步筛选损失的类型中进一步筛选出最终的一个或多个类型的损失。
3.根据权利要求2所述的方法,其中,对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围,基于模型训练目标和所有样本的损失值的取值范围,初步筛选一个或多个类型的损失,包括:
对原始训练集中所有样本的损失值进行统计学分析,得到所有样本的损失值的取值范围;
基于所有样本的损失值的取值范围获取高损失值的样本和低损失值的样本;
基于获取到的高损失值的样本和低损失值的样本,初步筛选一个或多个类型的损失。
4.根据权利要求2所述的方法,其中,若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,基于所述相关性从所述初步筛选损失的类型中进一步筛选出最终的一个或多个类型的损失,包括:
若初步筛选损失的类型为多个,分析不同类型的损失之间的相关性,得到相关性小于预设相关性阈值的一个或多个类型的损失;
将相关性小于预设相关性阈值的一个或多个类型的损失作为筛选出的最终的一个或多个类型的损失。
5.根据权利要求2所述的方法,其中,基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围,包括:
确定进一步筛选出的最终的一个或多个类型的损失的对应的损失值的容易样本阈值范围。
6.根据权利要求5所述的方法,其中,确定进一步筛选出的最终的一个或多个类型的损失的对应的损失值的容易样本阈值范围,包括:
对筛选出的最终的每一类型的损失,分析最终的每一类型的损失的损失值的取值范围在当前轮的迭代训练的各次迭代训练阶段的变化,当某次次迭代训练阶段的损失值的取值范围相对变化小于预设变化阈值时,确认模型此时已经收敛;
选取模型收敛时相对变化小于预设变化阈值的损失值的取值范围的预设百分位数作为容易样本阈值范围。
7.根据权利要求2所述的方法,其中,使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,基于所述判断结果,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集,包括:
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练的第一次迭代训练阶段,判断所述原始训练集中的每个样本的进一步筛选出最终的一个或多个类型的损失的损失值是否在对应的容易样本阈值范围之内;
若某个样本的某个进一步筛选出最终的一个或多个类型的损失的损失值在对应的容易样本阈值范围之内,则将该样本确定为容易样本;
对所述容易样本的当前迭代训练阶段的次数计数,以得到容易样本的计数值,判断容易样本的计数值是否超过训练迭代次数的次数阈值N;
若容易样本的计数值未超过训练迭代次数的次数阈值N,则使用原始训练集对所述模型进行的当前轮迭代训练的下一次的迭代训练后,对该容易样本重复执行所述对所述容易样本的当前迭代训练阶段的次数计数的步骤;;
若容易样本的计数值超过训练迭代次数的次数阈值N,将该容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集。
8.一种基于样本剔除的模型训练设备,其中,该设备包括:
统计组装置,用于使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
范围确定装置,用于基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
剔除装置,用于使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
训练装置,用于使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
9.一种基于计算的设备,其中,包括:
处理器;以及
被安排成存储计算机可执行指令的存储器,所述可执行指令在被执行时使所述处理器:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
10.一种计算机可读存储介质,其上存储有计算机可执行指令,其中,该计算机可执行指令被处理器执行时使得该处理器:
使用原始训练集中的样本对模型进行一轮迭代训练直至模型收敛,其中,每一轮迭代训练包括多次迭代训练;同时在当前轮的迭代训练的各次迭代训练阶段获取所述原始训练集中所有样本的损失值,并对原始训练集中所有样本的损失值进行统计学分析,得到统计学分析结果;
基于所述统计学分析结果,确定所述原始训练集中的样本的各类型的损失对应的损失值的容易样本阈值范围;
使用原始训练集对所述模型进行的新的一轮迭代训练,在当前轮的迭代训练中,判断所述原始训练集中的每个样本的各类型的损失的对应的损失值是否在相应的容易样本阈值范围之内,以得到所述原始训练集中容易样本,将容易样本从所述原始训练集中剔除,将原始训练集中剩余的样本作为新的训练集;
使用新的训练集对所述模型进行的当前轮迭代训练的后续的迭代训练,直至模型收敛。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011493895.9A CN112633515A (zh) | 2020-12-16 | 2020-12-16 | 基于样本剔除的模型训练方法及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011493895.9A CN112633515A (zh) | 2020-12-16 | 2020-12-16 | 基于样本剔除的模型训练方法及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112633515A true CN112633515A (zh) | 2021-04-09 |
Family
ID=75316311
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011493895.9A Pending CN112633515A (zh) | 2020-12-16 | 2020-12-16 | 基于样本剔除的模型训练方法及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112633515A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113141363A (zh) * | 2021-04-22 | 2021-07-20 | 西安交通大学 | 一种加密流量样本筛选方法、系统、设备及可读存储介质 |
CN114121204A (zh) * | 2021-12-09 | 2022-03-01 | 上海森亿医疗科技有限公司 | 基于患者主索引的患者记录匹配方法、存储介质及设备 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170300811A1 (en) * | 2016-04-14 | 2017-10-19 | Linkedin Corporation | Dynamic loss function based on statistics in loss layer of deep convolutional neural network |
CN111160161A (zh) * | 2019-12-18 | 2020-05-15 | 电子科技大学 | 一种基于噪声剔除的自步学习人脸年龄估计方法 |
CN111160406A (zh) * | 2019-12-10 | 2020-05-15 | 北京达佳互联信息技术有限公司 | 图像分类模型的训练方法、图像分类方法及装置 |
CN111753914A (zh) * | 2020-06-29 | 2020-10-09 | 北京百度网讯科技有限公司 | 模型优化方法和装置、电子设备及存储介质 |
-
2020
- 2020-12-16 CN CN202011493895.9A patent/CN112633515A/zh active Pending
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170300811A1 (en) * | 2016-04-14 | 2017-10-19 | Linkedin Corporation | Dynamic loss function based on statistics in loss layer of deep convolutional neural network |
CN111160406A (zh) * | 2019-12-10 | 2020-05-15 | 北京达佳互联信息技术有限公司 | 图像分类模型的训练方法、图像分类方法及装置 |
CN111160161A (zh) * | 2019-12-18 | 2020-05-15 | 电子科技大学 | 一种基于噪声剔除的自步学习人脸年龄估计方法 |
CN111753914A (zh) * | 2020-06-29 | 2020-10-09 | 北京百度网讯科技有限公司 | 模型优化方法和装置、电子设备及存储介质 |
Non-Patent Citations (2)
Title |
---|
YANYAO SHEN 等: "Learningwith bad training data via interatIve trimmed loss minimization", 《ARXIV:1810.11874V2[CS.LG]》, 18 February 2019 (2019-02-18), pages 1 - 30 * |
王学军;王文剑;曹飞龙;: "基于自步学习的加权稀疏表示人脸识别方法", 计算机应用, no. 11, 10 November 2017 (2017-11-10), pages 3145 - 3151 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113141363A (zh) * | 2021-04-22 | 2021-07-20 | 西安交通大学 | 一种加密流量样本筛选方法、系统、设备及可读存储介质 |
CN114121204A (zh) * | 2021-12-09 | 2022-03-01 | 上海森亿医疗科技有限公司 | 基于患者主索引的患者记录匹配方法、存储介质及设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP6771751B2 (ja) | リスク評価方法およびシステム | |
JP6751235B2 (ja) | 機械学習プログラム、機械学習方法、および機械学習装置 | |
CN110675399A (zh) | 屏幕外观瑕疵检测方法及设备 | |
CN109544166A (zh) | 一种风险识别方法和装置 | |
EP3712825A1 (en) | Model prediction method and device | |
CN112633515A (zh) | 基于样本剔除的模型训练方法及设备 | |
CN111160950B (zh) | 一种资源信息的处理、输出方法及装置 | |
CN109145981B (zh) | 深度学习自动化模型训练方法及设备 | |
CN111176565B (zh) | 确定应用的存储负载的方法和设备 | |
CN110827246A (zh) | 电子设备边框外观瑕疵检测方法及设备 | |
CN110334012B (zh) | 一种风险评估方法及装置 | |
CN112434717B (zh) | 一种模型训练方法及装置 | |
CN111275106A (zh) | 对抗样本生成方法、装置及计算机设备 | |
CN110852443A (zh) | 特征稳定性检测方法、设备及计算机可读介质 | |
CN107832271B (zh) | 函数图像绘制方法、装置、设备及计算机存储介质 | |
CN107274043B (zh) | 预测模型的质量评价方法、装置及电子设备 | |
CN114697127B (zh) | 一种基于云计算的业务会话风险处理方法及服务器 | |
CN114172705B (zh) | 基于模式识别的网络大数据分析方法和系统 | |
CN107886113B (zh) | 一种基于卡方检验的电磁频谱噪声提取和滤波方法 | |
CN114897723A (zh) | 一种基于生成式对抗网络的图像生成与加噪方法 | |
CN111078877B (zh) | 数据处理、文本分类模型的训练、文本分类方法和装置 | |
CN113032553A (zh) | 信息处理装置和信息处理方法 | |
CN110751197A (zh) | 图片分类方法、图片模型训练方法及设备 | |
CN111767980A (zh) | 模型优化方法、装置及设备 | |
CN115953248B (zh) | 基于沙普利可加性解释的风控方法、装置、设备及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |