CN113920369A - 一种模型训练方法、装置和电子设备 - Google Patents
一种模型训练方法、装置和电子设备 Download PDFInfo
- Publication number
- CN113920369A CN113920369A CN202111228575.5A CN202111228575A CN113920369A CN 113920369 A CN113920369 A CN 113920369A CN 202111228575 A CN202111228575 A CN 202111228575A CN 113920369 A CN113920369 A CN 113920369A
- Authority
- CN
- China
- Prior art keywords
- data
- data set
- training
- target
- unlabeled
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 84
- 238000013145 classification model Methods 0.000 claims abstract description 156
- 238000004590 computer program Methods 0.000 claims description 16
- 238000002372 labelling Methods 0.000 description 8
- 238000010586 diagram Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 238000004364 calculation method Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000002159 abnormal effect Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q40/00—Finance; Insurance; Tax strategies; Processing of corporate or income taxes
- G06Q40/02—Banking, e.g. interest calculation or account maintenance
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Business, Economics & Management (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Finance (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Accounting & Taxation (AREA)
- Evolutionary Biology (AREA)
- Development Economics (AREA)
- Economics (AREA)
- Marketing (AREA)
- Strategic Management (AREA)
- Technology Law (AREA)
- General Business, Economics & Management (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请提供了一种模型训练方法、装置和电子设备,所述模型训练方法包括:获取未标注数据集,所述未标注数据集包括多个未标注数据;将多个未标注数据输入分类模型,得到分类模型输出的多个分类结果,其中,每个未标注数据对应一个分类结果,分类结果用于表征对应的未标注数据的预测类别;基于多个分类结果,获取第一数据集;获取与第一数据集对应的第一目标数据集,第一目标数据集为对第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;基于第一目标数据集,对分类模型进行迭代训练,得到目标分类模型。本申请提供的一种模型训练方法、装置和电子设备,可以解决由于样本数量不平衡而导致的模型精度较低的问题。
Description
技术领域
本申请涉及数据处理领域,具体涉及一种模型训练方法、装置和电子设备。
背景技术
目前,在对分类模型进行训练过程中,在某些场景下,所能获取到的不同类别的训练数据的数量之间可能差别相当大。例如,在银行信用欺诈交易识别中,所获取到的历史交易数据中,属于欺诈交易的样本通常仅占很少一部分,绝大部分为正常交易的样本。在此情况下,由于训练数据中各类型的样本数量不平衡,所训练得到的模型倾向于将待识别数据分类至样本数量较多的类别,从而可能导致模型精度较低的问题。
发明内容
本申请提供的一种模型训练方法、装置和电子设备,可以解决由于样本数量不平衡而导致的模型精度较低的问题。
第一方面,本申请实施例提供了一种模型训练方法,包括:
获取未标注数据集,所述未标注数据集包括多个未标注数据;
将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
基于所述第一目标数据集,对所述分类模型进行迭代训练,得到目标分类模型。
第二方面,本申请实施例提供了一种模型训练装置,包括:
第一获取模块,用于获取未标注数据集,所述未标注数据集包括多个未标注数据;
预测模块,用于将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
第二获取模块,用于基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
第三获取模块,用于获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
训练模块,用于基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型。
第三方面,本申请实施例还提供了一种电子设备,包括处理器、存储器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如上述第一方面所述的方法步骤。
第四方面,本申请实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如上述第一方面所述的方法步骤。
本申请实施例中,通过基于分类模型对未标注数据进行分类,从而得到每个未标注数据的预测类别,这样,可以根据预测类别对未标注数据集中的未标注数据进行分类,然后,从每个类别中获取未标注数据形成第一数据集,并在确定第一数据集中每个未标注数据的真实类别之后,得到第一目标数据集,最后,基于第一目标数据集对分类模型进行迭代训练,以得到目标分类模型。该过程中,通过先对未标注数据进行分类,这样,在获取第一目标数据集时,可以相对均衡的从各个类别中获取对应数量的未标注数据进行标注,从而可以使得所获取到的第一目标数据集中各个类别的训练数据的数量相对均衡,进而可以有效的缓解样本数量不平衡的问题,进而可以提高训练得到的目标分类模型的精度。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的模型训练方法的流程图之一;
图2是本申请实施例提供的模型训练方法的流程图之二;
图3是本申请实施例提供的模型训练装置的结构示意图之一;
图4是本申请实施例提供的模型训练装置的结构示意图之二。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
针对背景技术中的由于各类型的样本数量不平衡,而导致的训练得到的模型的精度较低的问题。相关技术中主要存在过采样法和欠采样法两种手段,对训练数据中各类型的样本进行平衡。其中,所述过采样法是根据样本标签少的样本的规律去生成更多该标签样本,这样使得数据趋向于平衡。然而,采用过采样法平衡样本存在如下缺陷:因为其是用少量样本生成更多的样本,或者不断使用少量样本。会导致模型对这部分少量样本过拟合。所述欠采用法是通过减少类别中数据较多的数据,从而让类别平衡。然而,采用欠采样法平衡样本存在如下缺陷:由于没有利用到数据集中的所有数据,有信息损失,有欠拟合的风险。可见,现有的样本平衡方法均存在各自的缺陷。
基于此,本申请实施例提供的模型训练方法通过先基于分类模型对待标注数据集中的待标注数据进行分类,然后,从每个类别中获取n个待标注数据,以获得训练数据集。该过程中,包含了主动学习的思想。表现在标注过程中,就是主动去挑选“有价值”的数据进行标注。从而使最终结果优于随机选择。
本申请实施例提供的模型训练方法具体可以应用于:在分类模型的训练过程中,所能获取到的不同类别的训练数据的数量之间可能差别相当大。例如,由于在银行信用欺诈交易识别中,所获取到的历史交易数据中,属于欺诈交易的样本通常仅占很少一部分,绝大部分为正常交易的样本。因此,可以基于本申请实施例所提供的方法对训练银行信用欺诈交易识别模型进行训练。此外,还可以采用本申请实施例提供的方法对用户分类模型进行训练,其中,所述用户分类模型用于对全部用户进行分类,如分类为正常用户和非正常用户,还可包括其他分类结果。
请参见图1,为本申请实施例提供的一种模型训练方法,包括:
步骤101、获取未标注数据集,所述未标注数据集包括多个未标注数据;
步骤102、将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
步骤103、基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
步骤104、获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
步骤105、基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型。
其中,所述未标注数据集可以是特定场景下的大量待标注数据,本申请的目的在于训练得到目标分类模型,然后,基于目标分类模型对该场景下的未标注数据进行标注,即确定所述未标注数据的类别标签,从而实现对该场景下的未标注数据进行分类的作用。
上述分类模型可以是预先构建的分类模型,也可以预先对该分类模型进行初步训练,得到该分类模型具有一定的分类能力,但其分类结果的准确率相对较低,因此,可以基于本申请的方法对其进行进一步训练,以提高其分类的准确率。
上述预设类别的数量可以为至少两个,所述预测类别为所述至少两个预设类别中的任意预设类别。所述预设类别即在进行训练之前人为设定的类别。例如,当需要对某一批未标注数据进行二分类时,可以预先设置类别A和类别B作为预设类别,然后,分别判断每个未标注数据属于类别A和类别B中的哪一类别。相应地,上述预测类别为分类模型对未标注数据的类别进行预测得到的类别,该预测类别是否为对应未标注数据的真实类别取决于预测结果的准确性。所述真实类别即未标注数据真实的类别,其中,真实类别可以是根据预设规则进行分类得到的类别。
上述分类结果可以表示对应的待分类数据的预测类别,相应地,上述基于所述多个分类结果,获取第一数据集可以是指:基于所述多个分类结果,从所述未标注数据集中获取第一数据集。此外,所述多个分类结果还可以是指:将所述未标注数据集分类为多个子集合,其中,每个子集合中的未标注数据的预测类别相同,此时,所述基于所述多个分类结果,获取第一数据集可以是指:从每个子集合中分别获取未标注数据,得到所述第一数据集。
上述基于所述多个分类结果,获取第一数据集具体可以是从每个预设类别中获取相同或将近数量的未标注数据,以形成所述第一数据集。这样,可以确保第一数据集中各个类别的未标注数据的数量相对均衡,进而可以确保基于第一数据集生成的第一目标数据集中各个类别的标注数据的数量相对均衡,以克服样本不平衡的问题。
下文以分别从每个预设类别中获取相同数量的未标注数据,即从每个预设类别中获取n个未标注数据,以生成所述第一数据集为例,对本申请实施例提供的方法作进一步的解释说明。可以理解的是,上述分类模型对未标注数据进行识别输出的分类结果可能为正确的分类结果,也可能为错误的分类结果。假设所述分类模型分类的准确率为10%,所述未标注数据集包括10000条数据,且10000条数据分别属于A和B两个预设类别,其中,10000条数据中,A类别的数据量与B类别的数据量的比值为1:100。若采用现有技术中的方法随机从未标注数据集中获取一定数量(例如100条)的未标注数据进行人工标注,以得到训练数据集进行训练,则训练数据集中A类别的数据量与B类别的数据量的比值为1:99,存在较为严重的样本数量不平衡的问题。
而若先将所述10000条数据输入所述分类模型,获得每个待标注数据的预测类别之后,分别从A类别和B类别中获取50条数据,即上述n的取值为50,由于模型预测的准确率为10%,因此,所获取得到的50条A类别的数据中,实际为A类别的数据量可能为5条,在此情况下,所述训练数据集中,A类别的数据量与B类别的数据量的比值可能为:5:95,显然,相对于随机选取数据而言,可以有效的降低样本数量不平衡的问题,因此,相对于采用现有的方法而言,采用本申请实施例提供的模型训练方法能够提高训练得到的分类模型的分类准确率。
在具体实施时,可以采用本申请提供的方法对分类模型进行迭代训练。由于每次训练过程中的分类模型可以进一步学习到各个类别的数据之间存在的区别,以进一步提高分类模型对各个类别的数据的区分效果。因此,每次完成训练之后得到的分类模型相对于训练之前的分类模型的分类准确率理论上会有所提高。具体而言,假设在经过一次训练之后,所述分类模型的准确率由10%提高到20%,在下一次训练过程中,将所述10000条数据输入训练后得到的所述分类模型,获得每个待标注数据的预测类别之后,分别从A类别和B类别中获取50条数据,由于模型预测的准确率为20%,因此,所获取得到的50条A类别的数据中,实际为A类别的数据量可能为10条,在此情况下,所述训练数据集中,A类别的数据量与B类别的数据量的比值可能为:10:90,显然,在经过一次迭代之后,在下一次训练过程中,可以进一步降低样本数量不平衡的问题,从而有利于进一步提高分类模型分类的准确性。如此,随着迭代次数的增加,分类模型分类的准确性不断提高,在基于分类模型的预测结果获取到的第一目标数据集中各类别的数据的数量也将趋于均衡,从而可以进一步缓解由于样本数量不平衡而导致的模型精度较低的问题。
该实施方式中,通过基于分类模型对未标注数据进行分类,从而得到每个未标注数据的预测类别,这样,可以根据预测类别对未标注数据集中的未标注数据进行分类,然后,从每个类别中获取未标注数据形成第一数据集,并在确定第一数据集中每个未标注数据的真实类别之后,得到第一目标数据集,最后,基于第一目标数据集对分类模型进行迭代训练,以得到目标分类模型。该过程中,通过先对未标注数据进行分类,这样,在获取第一目标数据集时,可以相对均衡的从各个类别中获取对应数量的未标注数据进行标注,从而可以使得所获取到的第一目标数据集中各个类别的训练数据的数量相对均衡,进而可以有效的缓解样本数量不平衡的问题,进而可以提高训练得到的目标分类模型的精度。
可选地,所述基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型,包括:
所述迭代训练共进行I次训练,其中,
所述迭代训练中的第i次训练包括:
取目标训练数据集与所述第一目标数据集的并集,得到第i组训练数据集;
基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型;
其中,在所述i等于1的情况下,所述分类模型为预先构建的初始分类模型,所述目标训练数据集为预先构建的初始训练数据集;
在所述i不等于1的情况下,所述分类模型为第i-1分类模型,所述目标训练数据集为第i-1组训练数据集。
具体地,进行模型训练之前,可以预先构建初始分类模型,同时获取初始训练数据集,所述初始训练数据集可以包括所述至少两个预设类别中的每个预设类别对应的预设个标注数据,例如,所述可以认为获取每个预设类别对应的50条标注数据,以形成所述初始训练数据集,其中,所述标注数据为具有真实类别标签的数据。
然后,获取未标注数据集,将未标注数据集输入对初始分类模型,得到对初始分类模型输出的每个未标注数据的预测类别之后,基于上述实施例所述的方法得到第一目标数据集1,然后,将第一目标数据集1与初始训练数据集取并集,得到第1组训练数据集,基于第1组训练数据集对初始分类模型进行训练,得到训练之后的第1分类模型。
然后,再次将未标注数据集输入第1分类模型,得到第1分类模型输出的每个未标注数据的预测类别之后,可以基于第1分类模型的输出结果,判断第1分类模型是否满足收敛条件,当满足收敛条件时,将第1分类模型确定为目标分类模型,结束模型训练过程。其中,所述收敛条件可以是指第1分类模型输出的分类结果的准确率是否达到预设值,例如,当所述第1分类模型输出的分类结果的准确率高于95%时,确定模型收敛。
当所述第1分类模型不满足所述收敛条件时,可以基于所述第1分类模型输出的预测结果,获取第一目标数据集2,可以理解的是,基于初始分类模型输出的预测结果,获取的第一目标数据集1与基于第1分类模型输出的预测结果,获取的第一目标数据集2为不同的数据集。然后,将第一目标数据集2与第1组训练数据集取并集,得到第2训练数据集,基于第2训练数据集对第1分类模型进行训练,得到训练之后的第2分类模型。
然后,再次将未标注数据集输入第2分类模型,得到第2分类模型输出的每个未标注数据的预测类别之后,判断第2分类模型是否收敛,在第2分类模型收敛的情况下,将第2分类模型确定为目标分类模型。在第2分类模型不收敛的情况下,基于上述方法进一步迭代训练,直至训练得到的模型满足收敛条件。
其中,所述收敛条件可以是指训练迭代训练的次数达到上述I次。此外,所述收敛条件还可以是指训练得到的目标分类模型的精度达到预设精度。
可选地,所述基于所述多个分类结果,获取第一数据集,包括:
基于所述分类结果对所述多个未标注数据进行分类,得到至少两个子集合,其中,一个子集合对应一个预设类别;
从每个子集合中,获取m个第一未标注数据和k个第二未标注数据,所述第一未标注数据为所述分类结果预测的准确率小于第一阈值的未标注数据,所述第二未标注数据为所述分类结果预测的准确率大于或等于所述第一阈值的未标注数据;
将每个子集合中的所述m个第一未标注数据和所述k个第二未标注数据确定为所述第一数据集中的数据。
当上述第一数据集为从每个预设类别中获取n个未标注数据,得到的数据集时,所述m与所述k之和为所述n。
其中,所述基于所述分类结果对所述多个未标注数据进行分类,可以是至按照每个未标注数据对应的预测类别对所述未标注数据集中的数据进行聚类,以形成所述至少两个子集合。例如,假设所述至少两个预设类别包括A类别和B类别,则可将所述未标注数据集中预测类别为A类别的全部未标注数据划分至子集A合,以及,将所述未标注数据集中预测类别为B类别的全部未标注数据划分至子集合B,从而得到所述至少两个子集合。
上述分类模型在对每个未标注数据进行预测时,可以输出每个未标注数据属于所述至少两个预设类别中各个预设类别的概率,这样,可以基于所述分类模型的输出结果计算每个未标注数据的熵,所述未标注数据的熵越高,对应的分类结果的预测的准确率越低。其中,上述熵可以基于现有的熵的计算方法进行计算,例如,可以从采用如下公式计算每个未标注数据的熵:
-sum(pk*log(pk),axis=-1)
其中,pk包括未标注数据属于每个预设类别的概率值。
在计算得到每个未标注数据的熵之后,可以按照熵的大小,对每个子集合中的未标注数据进行排序,从而可以将每个子集合中的未标注数据分类为:准确率大于第一阈值的部分和准确率小于或等于第一阈值的部分。由于预测结果的准确率越高,预测类别的可靠性也就越高。这样,在模型训练的初期,由于分类模型自身的准确率较低,因此,在从每个子集合中获取n个未标注数据时,可以从准确率大于第一阈值的部分获取大量未标注数据,而从准确率小于或等于第一阈值的部分获取少量未标注数据,即使所述m小于所述k,例如,所述m的取值为0.2n,所述k的取值为0.8n。这样,可以确保所获取得到的第一数据集中,尽量多的包括少类别的未标注数据,以进一步缓解样本数量不平衡的问题。
可以理解的是,由于所述目标未标注数据的分类结果包括:目标未标注数据属于所述至少两个预设类别中各个预设类别的概率,因此,在确定目标未标注数据的预测类别时,即可将所述目标未标注数据的分类结果中,概率值最高的预设类别确定为所述目标未标注数据的预测类别。例如,目标未标注数据的分类结果为:属于类别A的概率为0.7、属于类别B的概率为0.5,则此时,可以将所述目标未标注数据的预测类别确定为类别A。
可选地,所述从每个子集合中,获取m个第一未标注数据和k个第二未标注数据之前,所述方法还包括:
基于所述多个分类结果确定所述分类模型的目标准确率;
在所述目标准确率小于第二阈值的情况下,所述m小于所述k;
在所述目标准确率大于或等于所述第二阈值的情况下,所述m大于所述k。
其中,上述基于所述多个分类结果确定所述分类模型的目标准确率的具体过程可以包括:从所述多个分类结果中随机采样s1个分类结果,然后,所述判断所述s1个分类结果中,分类正确的数量s2,利用s2除以s1即可得到所述分类模型的目标准确率。
具体地,可以根据目标准确率与第二阈值的相对大小,将模型训练分类不同的阶段,当所述目标准确率小于第二阈值的情况下,可以视为模型训练的初期;当所述目标准确率大于或等于所述第二阈值的情况下,可以视为模型训练的稳定期,其中,所述第二阈值可以根据模型收敛条件进行确定,例如,模型收敛条件为目标准确率大于95%时,则所述第二阈值的取值可以是75%-85%之间的任意一个数值。
在模型训练的初期,由于分类模型自身的准确率较低,因此,在从每个子集合中获取n个未标注数据时,可以从准确率大于第一阈值的部分获取大量未标注数据,而从准确率小于或等于第一阈值的部分获取少量未标注数据,即使所述m小于所述k,例如,所述m的取值为0.2n,所述k的取值为0.8n。这样,可以确保所获取得到的第一数据集中,尽量多的包括少类别的未标注数据,以进一步缓解样本数量不平衡的问题。
在模型训练的稳定期,由于模型的目标准确率可以达到75%以上,因此,所获取得到的第一目标数据集中各类别的待标注数据的数量相对均衡。此时,为了进一步提高模型的准确率,在本次训练时,可以向模型输入大量其识别结果准确率较低的训练数据,以加强模型对识别准确率较低的训练数据的学习,从而进一步提高训练得到的模型的精度。例如,在模型训练的稳定期,在从每个子集合中获取n个未标注数据时,可以从准确率大于第一阈值的部分获取少量未标注数据,而从准确率小于或等于第一阈值的部分获取大量未标注数据,即使所述m大于所述k,例如,所述m的取值为0.8n,所述k的取值为0.2n。
可选地,所述分类模型包括与所述预设类别对应的权重参数,所述基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型之前,所述方法还包括:
基于所述第i组训练数据集对所述分类模型的所述权重参数进行更新,其中,所述第i组训练数据集中,目标类别对应的标注数据的数量越多,所述目标类别的权重参数越小,所述目标类别为所述预设类别中的任意预设类别。
具体地,在所述预设类别的数量为至少两个预设类别的情况下,所述分类模型包括与所述至少两个预设类别一一对应的至少两个权重参数,所述基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型之前,所述方法还包括:基于所述第i组训练数据集对所述分类模型的所述至少两个权重参数进行更新,其中,所述第i组训练数据集中,目标类别对应的标注数据的数量越多,所述目标类别的权重参数越小,所述目标类别为所述至少两个预设类别中的任意预设类别。
由于第i组训练数据集中,各类别的训练数据的数量可能不同,为平衡样本,当某一预设类别在第i组训练数据集中的标注数据的数量越多时,可以减小该预设类别的权重,相应地,当某一预设类别在第i组训练数据集中的标注数据的数量越少时,可以增大该预设类别的权重。从而可以进一步缓解由于样本数量不平衡,而导致的训练得到的模型的精度低的问题。
在本申请一个实施例中,具体可以采用如下公式计算每个预设类别的权重值:
Q1=(1/L1)*(L/r)
其中,所述Q1为目标类别的权重值,所述L1为第i组训练数据集中,标注数据为目标类别的数量,L为第i组训练数据集中标注数据的总数量,r为所述预设类别的数量。
请参见图2,为本申请实施例提供的一种具体的模型训练方法,所述方法包括如下步骤:获取未标注数据集,从未标注数据集中获取一定数量的未标注数据,并对所获取的未标注数据进行标注,得到初始数据集(即标注数据集),然后,将未标注数据集输入模型进行训练,得到模型输出的预测结果,基于模型输出的预测结果,计算每个未标注数据的熵,同时,计算模型的目标准确率,在模型的目标准确率低于第二阈值的情况下,确定模型处于训练初期,此时,取m=0.2n,k=0.8n按照上述实施例中的方法,获取第一目标数据集,然后,将第一目标数据集添加至所述标注数据集,得到训练数据集对模型进行训练。相应地,在模型的目标准确率高于或等于第二阈值的情况下,确定模型处于稳定期,此时,取m=0.8n,k=0.2n按照上述实施例中的方法,获取第一目标数据集,然后,将第一目标数据集添加至所述标注数据集,得到训练数据集对模型进行训练,直至训练得到的模型收敛,将最终的模型作为目标分类模型进行输出。
具体地,在模型训练的初期,由于分类模型自身的准确率较低,因此,在从每个子集合中获取n个未标注数据时,可以从准确率大于第一阈值的部分获取大量未标注数据,而从准确率小于或等于第一阈值的部分获取少量未标注数据,即使所述m小于所述k,例如,所述m的取值为0.2n,所述k的取值为0.8n。这样,可以确保所获取得到的第一数据集中,尽量多的包括少类别的未标注数据,以进一步缓解样本数量不平衡的问题。
在模型训练的稳定期,由于模型的目标准确率较高,因此,所获取得到的第一目标数据集中各类别的待标注数据的数量相对均衡。此时,为了进一步提高模型的准确率,在进行训练时,可以向模型输入大量其识别结果准确率较低的训练数据,以加强模型对识别准确率较低的训练数据的学习,从而进一步提高训练得到的模型的精度。例如,在模型训练的稳定期,在从每个子集合中获取n个未标注数据时,可以从准确率大于第一阈值的部分获取少量未标注数据,而从准确率小于或等于第一阈值的部分获取大量未标注数据,即使所述m大于所述k,例如,所述m的取值为0.8n,所述k的取值为0.2n。
本实施例提供您的方法的具体实现过程可以参见上述实施例,为避免重复,在此不再予以赘述。
请参见图3,为本申请实施例提供的一种模型训练装置300的结构示意图,所述装置包括:
第一获取模块301,用于获取未标注数据集,所述未标注数据集包括多个未标注数据;
预测模块302,用于将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
第二获取模块303,用于基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
第三获取模块304,用于获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
训练模块305,用于基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型。
可选地,所述第二获取模块303,包括:
分类子模块,用于基于所述分类结果对所述多个未标注数据进行分类,得到至少两个子集合,其中,一个子集合对应一个预设类别;
第一获取子模块,用于从每个子集合中,获取m个第一未标注数据和k个第二未标注数据,所述第一未标注数据为所述分类结果预测的准确率小于第一阈值的未标注数据,所述第二未标注数据为所述分类结果预测的准确率大于或等于所述第一阈值的未标注数据;
第一确定子模块,用于将每个子集合中的所述m个第一未标注数据和所述k个第二未标注数据确定为所述第一数据集中的数据。
可选地,所述装置还包括:
确定模块,用于基于所述多个分类结果确定所述分类模型的目标准确率;
在所述目标准确率小于第二阈值的情况下,所述m小于所述k;
在所述目标准确率大于或等于所述第二阈值的情况下,所述m大于所述k。
可选地,所述分类结果包括目标未标注数据属于各个预设类别的概率,所述目标未标注数据为与所述分类结果对应的未标注数据,所述装置还包括:
计算模块,用于基于所述分类结果计算每个未标注数据的熵。
可选地,所述迭代训练共进行I次训练,其中,
所述迭代训练中的第i次训练包括:
取目标训练数据集与所述第一目标数据集的并集,得到第i组训练数据集;
基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型;
其中,在所述i等于1的情况下,所述分类模型为预先构建的初始分类模型,所述目标训练数据集为预先构建的初始训练数据集;
在所述i不等于1的情况下,所述分类模型为第i-1分类模型,所述目标训练数据集为第i-1组训练数据集。
可选地,所述装置还包括:
第四获取模块,用于获取所述初始训练数据集,所述初始训练数据集包括每个所述预设类别对应的预设个标注数据。
可选地,所述分类模型包括与所述预设类别对应的权重参数,所述装置还包括:
参数更新模块,用于基于所述第i组训练数据集对所述分类模型的所述权重参数进行更新,其中,所述第i组训练数据集中,目标类别对应的标注数据的数量越多,所述目标类别的权重参数越小,所述目标类别为所述预设类别中的任意预设类别。
本申请实施例提供的模型训练装置300能够实现上述模型训练方法实施例中的各个过程,为避免重复,这里不再赘述。
参见图4,图4是本申请另一实施提供的模型训练装置400的结构图,如图4所示,模型训练装置400包括:处理器401、存储器402及存储在所述存储器402上并可在所述处理器上运行的计算机程序,模型训练装置400中的各个组件通过总线接口403耦合在一起,所述计算机程序被所述处理器401执行时实现如下步骤:
获取未标注数据集,所述未标注数据集包括多个未标注数据;
将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
基于所述第一目标数据集,对所述分类模型进行迭代训练,得到目标分类模型。
可选地,所述基于所述多个分类结果,获取第一数据集,包括:
基于所述分类结果对所述多个未标注数据进行分类,得到至少两个子集合,其中,一个子集合对应一个预设类别;
从每个子集合中,获取m个第一未标注数据和k个第二未标注数据,所述第一未标注数据为所述分类结果预测的准确率小于第一阈值的未标注数据,所述第二未标注数据为所述分类结果预测的准确率大于或等于所述第一阈值的未标注数据;
将每个子集合中的所述m个第一未标注数据和所述k个第二未标注数据确定为所述第一数据集中的数据。
可选地,所述从每个子集合中,获取m个第一未标注数据和k个第二未标注数据之前,所述方法还包括:
基于所述多个分类结果确定所述分类模型的目标准确率;
在所述目标准确率小于第二阈值的情况下,所述m小于所述k;
在所述目标准确率大于或等于所述第二阈值的情况下,所述m大于所述k。
可选地,所述分类结果包括目标未标注数据属于各个预设类别的概率,所述目标未标注数据为与所述分类结果对应的未标注数据,所述从每个子集合中,获取m个第一未标注数据和k个第二未标注数据之前,还包括:
基于所述分类结果计算每个未标注数据的熵。
可选地,所述基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型,包括:
所述迭代训练共进行I次训练,其中,
所述迭代训练中的第i次训练包括:
取目标训练数据集与所述第一目标数据集的并集,得到第i组训练数据集;
基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型;
其中,在所述i等于1的情况下,所述分类模型为预先构建的初始分类模型,所述目标训练数据集为预先构建的初始训练数据集;
在所述i不等于1的情况下,所述分类模型为第i-1分类模型,所述目标训练数据集为第i-1组训练数据集。
可选地,所述基于所述未标注数据集和所述第一目标数据集对所述分类模型进行迭代训练,得到所述目标分类模型之前,所述方法还包括:
获取所述初始训练数据集,所述初始训练数据集包括每个所述预设类别对应的预设个标注数据。
可选地,所述分类模型包括与所述预设类别对应的权重参数,所述基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型之前,所述方法还包括:
基于所述第i组训练数据集对所述分类模型的所述权重参数进行更新,其中,所述第i组训练数据集中,目标类别对应的标注数据的数量越多,所述目标类别的权重参数越小,所述目标类别为所述预设类别中的任意预设类别。
本申请实施例还提供一种电子设备,包括处理器,存储器,存储在存储器上并可在所述处理器上运行的计算机程序,该计算机程序被处理器执行时实现上述方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
本申请实施例还提供一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,该计算机程序被处理器执行时实现上述方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。其中,所述的计算机可读存储介质,如只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random Access Memory,简称RAM)、磁碟或者光盘等。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台电子设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本申请各个实施例所述的方法。
上面结合附图对本申请的实施例进行了描述,但是本申请并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本申请的启示下,在不脱离本申请宗旨和权利要求所保护的范围情况下,还可做出很多形式,均属于本申请的保护之内。
Claims (10)
1.一种模型训练方法,其特征在于,包括:
获取未标注数据集,所述未标注数据集包括多个未标注数据;
将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
基于所述第一目标数据集,对所述分类模型进行迭代训练,得到目标分类模型。
2.根据权利要求1所述的方法,其特征在于,所述基于所述多个分类结果,获取第一数据集,包括:
基于所述分类结果对所述多个未标注数据进行分类,得到至少两个子集合,其中,一个子集合对应一个预设类别;
从每个子集合中,获取m个第一未标注数据和k个第二未标注数据,所述第一未标注数据为所述分类结果预测的准确率小于第一阈值的未标注数据,所述第二未标注数据为所述分类结果预测的准确率大于或等于所述第一阈值的未标注数据;
将每个子集合中的所述m个第一未标注数据和所述k个第二未标注数据确定为所述第一数据集中的数据。
3.根据权利要求2所述的方法,其特征在于,所述从每个子集合中,获取m个第一未标注数据和k个第二未标注数据之前,所述方法还包括:
基于所述多个分类结果确定所述分类模型的目标准确率;
在所述目标准确率小于第二阈值的情况下,所述m小于所述k;
在所述目标准确率大于或等于所述第二阈值的情况下,所述m大于所述k。
4.根据权利要求2所述的方法,其特征在于,所述分类结果包括目标未标注数据属于各个预设类别的概率,所述目标未标注数据为与所述分类结果对应的未标注数据,所述从每个子集合中,获取m个第一未标注数据和k个第二未标注数据之前,还包括:
基于所述分类结果计算每个未标注数据的熵。
5.根据权利要求1所述的方法,其特征在于,所述基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型,包括:
所述迭代训练共进行I次训练,其中,
所述迭代训练中的第i次训练包括:
取目标训练数据集与所述第一目标数据集的并集,得到第i组训练数据集;
基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型;
其中,在所述i等于1的情况下,所述分类模型为预先构建的初始分类模型,所述目标训练数据集为预先构建的初始训练数据集;
在所述i不等于1的情况下,所述分类模型为第i-1分类模型,所述目标训练数据集为第i-1组训练数据集。
6.根据权利要求5所述的方法,其特征在于,所述基于所述未标注数据集和所述第一目标数据集对所述分类模型进行迭代训练,得到所述目标分类模型之前,所述方法还包括:
获取所述初始训练数据集,所述初始训练数据集包括每个所述预设类别对应的预设个标注数据。
7.根据权利要求5所述的方法,其特征在于,所述分类模型包括与所述预设类别对应的权重参数,所述基于所述第i组训练数据集对所述分类模型进行训练,得到第i个分类模型之前,所述方法还包括:
基于所述第i组训练数据集对所述分类模型的所述权重参数进行更新,其中,所述第i组训练数据集中,目标类别对应的标注数据的数量越多,所述目标类别的权重参数越小,所述目标类别为所述预设类别中的任意预设类别。
8.一种模型训练装置,其特征在于,包括:
第一获取模块,用于获取未标注数据集,所述未标注数据集包括多个未标注数据;
预测模块,用于将所述多个未标注数据输入分类模型,得到所述分类模型输出的多个分类结果,其中,每个所述未标注数据对应一个所述分类结果,所述分类结果用于表征对应的所述未标注数据的预测类别,所述预测类别为预设类别中的类别;
第二获取模块,用于基于所述多个分类结果,获取第一数据集,其中,所述第一数据集包括至少两个未标注数据组,且未标注数据组中的未标注数据的预测类别为未标注数据组对应的预设类别;
第三获取模块,用于获取与所述第一数据集对应的第一目标数据集,所述第一目标数据集为对所述第一数据集中的每个未标注数据标注真实类别之后,得到的数据集;
训练模块,用于基于所述第一目标数据集对所述分类模型进行迭代训练,得到目标分类模型。
9.一种电子设备,其特征在于,包括处理器、存储器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如权利要求1至7中任一项所述的方法步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至7中任一项所述的方法步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111228575.5A CN113920369A (zh) | 2021-10-21 | 2021-10-21 | 一种模型训练方法、装置和电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111228575.5A CN113920369A (zh) | 2021-10-21 | 2021-10-21 | 一种模型训练方法、装置和电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113920369A true CN113920369A (zh) | 2022-01-11 |
Family
ID=79242174
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111228575.5A Pending CN113920369A (zh) | 2021-10-21 | 2021-10-21 | 一种模型训练方法、装置和电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113920369A (zh) |
-
2021
- 2021-10-21 CN CN202111228575.5A patent/CN113920369A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109872162B (zh) | 一种处理用户投诉信息的风控分类识别方法及系统 | |
CN113254643B (zh) | 文本分类方法、装置、电子设备和 | |
CN110930218B (zh) | 一种识别欺诈客户的方法、装置及电子设备 | |
CN111797320A (zh) | 数据处理方法、装置、设备及存储介质 | |
CN112950347A (zh) | 资源数据处理的优化方法及装置、存储介质、终端 | |
CN115130536A (zh) | 特征提取模型的训练方法、数据处理方法、装置及设备 | |
CN111091408A (zh) | 用户识别模型创建方法、装置与识别方法、装置 | |
CN114663002A (zh) | 一种自动化匹配绩效考核指标的方法及设备 | |
CN111582315A (zh) | 样本数据处理方法、装置及电子设备 | |
CN113011961B (zh) | 公司关联信息风险监测方法、装置、设备及存储介质 | |
CN114169439A (zh) | 异常通信号码的识别方法、装置、电子设备和可读介质 | |
CN115423600B (zh) | 数据筛选方法、装置、介质及电子设备 | |
CN111143533A (zh) | 一种基于用户行为数据的客服方法及系统 | |
CN110717817A (zh) | 贷前审核方法及装置、电子设备和计算机可读存储介质 | |
CN111401675A (zh) | 基于相似性的风险识别方法、装置、设备及存储介质 | |
CN113920369A (zh) | 一种模型训练方法、装置和电子设备 | |
CN110570301B (zh) | 风险识别方法、装置、设备及介质 | |
CN109308565B (zh) | 人群绩效等级识别方法、装置、存储介质及计算机设备 | |
CN113569957A (zh) | 一种业务对象的对象类型识别方法、装置及存储介质 | |
CN113850670A (zh) | 银行产品推荐方法、装置、设备及存储介质 | |
CN113724700A (zh) | 语种识别、语种识别模型训练方法及装置 | |
CN112632229A (zh) | 文本聚类方法及装置 | |
CN116187299B (zh) | 一种科技项目文本数据检定评价方法、系统及介质 | |
CN113535805B (zh) | 数据挖掘方法及相关装置和电子设备、存储介质 | |
CN111881287B (zh) | 一种分类模糊性分析方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |