CN110659657A - 训练模型的方法和装置 - Google Patents
训练模型的方法和装置 Download PDFInfo
- Publication number
- CN110659657A CN110659657A CN201810695632.2A CN201810695632A CN110659657A CN 110659657 A CN110659657 A CN 110659657A CN 201810695632 A CN201810695632 A CN 201810695632A CN 110659657 A CN110659657 A CN 110659657A
- Authority
- CN
- China
- Prior art keywords
- sample
- training
- discriminator
- value
- generator
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 200
- 238000000034 method Methods 0.000 title claims abstract description 46
- 230000000694 effects Effects 0.000 claims abstract description 41
- 238000013256 Gubra-Amylin NASH model Methods 0.000 claims abstract description 40
- 238000002372 labelling Methods 0.000 claims abstract description 15
- 238000011156 evaluation Methods 0.000 claims description 52
- 238000012360 testing method Methods 0.000 claims description 23
- 238000004422 calculation algorithm Methods 0.000 claims description 10
- 238000004590 computer program Methods 0.000 claims description 9
- 238000007637 random forest analysis Methods 0.000 claims description 8
- 238000013441 quality evaluation Methods 0.000 description 26
- 238000010586 diagram Methods 0.000 description 17
- 230000008569 process Effects 0.000 description 15
- 238000004891 communication Methods 0.000 description 7
- 230000006870 function Effects 0.000 description 6
- 238000012545 processing Methods 0.000 description 5
- 230000003042 antagnostic effect Effects 0.000 description 4
- 230000008859 change Effects 0.000 description 4
- 238000013136 deep learning model Methods 0.000 description 4
- 230000003287 optical effect Effects 0.000 description 4
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000008901 benefit Effects 0.000 description 2
- 239000000835 fiber Substances 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 239000004065 semiconductor Substances 0.000 description 2
- 238000010276 construction Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000003203 everyday effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2413—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种训练模型的方法和装置,涉及计算机技术领域。其中,该方法包括:基于训练样本对GAN模型中的生成器和判别器进行训练;其中,所述训练样本为人工标记样本;通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本;基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。通过以上步骤,能够减少模型训练对人工标注数据的依赖,并且能够让模型根据训练效果自我优化。
Description
技术领域
本发明涉及计算机技术领域,尤其涉及一种训练模型的方法和装置。
背景技术
目前,在对文本数据进行分类处理时,往往是由专家提供标注数据,再通过标注数据训练模型进行分类预测。图片数据的标注比较客观,人人都可以标注。但是,文本数据的标注比较复杂,依赖专家的主观判别,因此对专家标注具有较强的依赖性。
在实现本发明过程中,发明人发现现有技术中至少存在如下问题:第一、依赖专家人工标记样本数据,不仅工作量大而且耗时。第二、如果标注数据量过少,可能会导致模型泛化能力不够或者在训练样本上过拟合。第三、用户活跃的平台每天都能产生大量的评论,而这些评论的分布是不均匀的。以商品评价为例,很可能低质量评价占商品评价的绝大多数,那么标注高质量评价的成本就会暴增。
发明内容
有鉴于此,本发明提供一种训练模型的方法和装置,能够减少模型训练对人工标注数据的依赖,并且能够让模型根据训练效果自我优化。
为实现上述目的,根据本发明的一个方面,提供了一种训练模型的方法。
本发明的训练模型的方法包括:基于训练样本对GAN模型中的生成器和判别器进行训练;其中,所述训练样本为人工标记样本;通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本;基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
可选地,所述方法还包括:在执行所述基于训练样本对GAN模型中的生成器和判别器进行训练的步骤之后,通过所述判别器对测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第一取值;其中,所述测试样本为人工标记样本;在执行所述基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练的步骤之后,通过所述判别器对所述测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第二取值;以及,在第二取值高于第一取值的情况下,再次执行所述通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本的步骤。
可选地,所述方法还包括:在第二取值低于或等于第一取值的情况下,向所述训练样本中追加人工标记样本。
可选地,所述预测效果评估参数包括以下至少一项:准确率、ROC、AUC。
可选地,所述判别器采用随机森林算法。
为实现上述目的,根据本发明的另一方面,提供了一种训练模型的装置。
本发明的训练模型的装置包括:训练模块,用于基于训练样本对GAN模型中的生成器和判别器进行训练;所述训练样本为人工标记样本;样本添加模块,用于通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本;所述训练模块,还用于基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
可选地,所述装置还包括:评估模块,用于在所述训练模块基于训练样本对GAN模型中的生成器和判别器进行训练之后,通过所述判别器对测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第一取值;其中,所述测试样本为人工标记样本;所述评估模块,还用于在所述训练模块基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练之后,通过所述判别器对所述测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第二取值;所述样本添加模块,还用于在第二取值高于第一取值的情况下,再次执行所述通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本的步骤。
可选地,所述样本添加模块,还用于在第二取值低于或等于第一取值的情况下,向所述训练样本中追加人工标记样本。
为实现上述目的,根据本发明的再一个方面,提供了一种电子设备。
本发明的电子设备,包括:一个或多个处理器;以及,存储装置,用于存储一个或多个程序;当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现本发明的训练模型的方法。
为实现上述目的,根据本发明的又一个方面,提供了一种计算机可读介质。
本发明的计算机可读介质,其上存储有计算机程序,所述程序被处理器执行时实现本发明的训练模型的方法。
上述发明中的一个实施例具有如下优点或有益效果:通过基于判别器对无标记样本和/或由生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至训练样本,并基于新的训练样本对GAN模型中的生成器和判别器进行再训练等步骤,能够减少模型训练对人工标注数据(尤其是专家标注数据)的依赖,并且能够让模型根据训练效果自我优化。
上述的非惯用的可选方式所具有的进一步效果将在下文中结合具体实施方式加以说明。
附图说明
附图用于更好地理解本发明,不构成对本发明的不当限定。其中:
图1是根据本发明一个实施例的训练模型的方法的主要步骤示意图;
图2是根据本发明另一实施例的训练模型的方法的主要步骤示意图;
图3是GAN模型训练的原理示意图;
图4是根据本发明一个实施例的训练模型的装置的主要模块示意图;
图5是根据本发明另一实施例的训练模型的装置的主要模块示意图;
图6是本发明实施例可以应用于其中的示例性系统架构图;
图7是适于用来实现本发明实施例的电子设备的计算机系统的结构示意图。
具体实施方式
以下结合附图对本发明的示范性实施例做出说明,其中包括本发明实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本发明的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
需要指出的是,在不冲突的情况下,本发明中的实施例以及实施例中的特征可以相互组合。
图1是根据本发明一个实施例的训练模型的方法的主要步骤示意图。如图1所示,本发明实施例的训练模型的方法包括:
步骤S101、基于训练样本对GAN模型中的生成器和判别器进行训练。
其中,所述训练样本为人工标记样本。具体实施时,可由专家提供包含正负样本的数据。以商品评价为例,可由专家对商品评价进行人工打标,比如将商品评价人工标记为高质量评价(正样本)和低质量评价(负样本)。在得到人工标记样本之后,再将其分割成训练样本和测试样本。
GAN(生成式对抗网络)模型是一种深度学习模型,其包括生成器和判别器。在GAN模型的训练过程中,生成器主要用来学习输入的训练样本(比如文本、图像数据)的真实分布,从而让自身生成的伪样本更加贴近真实的训练样本,而判别器则努力地识别输入的样本的真假。这个过程相当于一个二人博弈的过程,随着时间的推移,生成器和判别器不断地进行对抗式训练,最终两个网络达到了一个动态平衡:生成器生成的伪样本接近于真实样本,而判别器识别不出真样本和伪样本。具体实施时,在训练效果稳定以后,比如判别器的识别准确率变化波动处于-3%~3%的范围时,可执行步骤S102。
步骤S102、通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本。
在一可选实施方式中,可将获取的无标记样本和通过步骤S101训练后的生成器所生成的伪样本输入判别器,以进行机器打标(即通过判别器对输入的无标记样本和伪样本的标签进行预测)。其中,无标记样本可理解为获取的不带标签的真实样本。以商品评价为例,判别器预测的标签有三类:低质量评价(可以用0表示)、高质量评价(可以用-1表示)、假数据(或称为“GAN中的生成器生成的数据”,其可以用-2表示)。其中,预测标签为低质量评价和高质量评价可统一看作是预测为真样本。在进行机器打标之后,可将预测标签为低质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,并将预测标签为高质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,从而得到了新的训练样本。
在另一可选实施方式中,也可将获取的无标记样本或通过步骤S101训练后的生成器所生成的伪样本输入判别器,以进行机器打标。然后,将判别器预测为真样本且标签概率值大于预设阈值的机器打标数据添加至训练样本,以得到新的训练样本。
步骤S103、基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
在本发明实施例中,通过以上步骤能够减少模型训练对人工标注数据(尤其是专家标注数据)的依赖,并且能够让模型根据训练效果自我优化。
图2是根据本发明另一实施例的训练模型的方法的主要步骤示意图。如图2所示,本发明实施例的训练模型的方法包括:
步骤S201、将人工标记样本分割成训练样本和测试样本。
具体实施时,所述人工标记样本可为专家提供的包含正负样本的数据。以商品评价为例,可由专家对商品评价进行人工打标,比如将商品评价人工标记为高质量评价(正样本)和低质量评价(负样本),并且不同类别的标记样本需要满足一定数量。在得到人工标记样本之后,可按照80%、20%的比例将其切分成训练样本和测试样本。另外,也可按照70%、30%的比例将人工标记样本切分成训练样本和测试样本。
步骤S202、基于训练样本对GAN模型中的生成器和判别器进行训练。
GAN(生成式对抗网络)模型是一种深度学习模型,其包括生成器和判别器。在GAN模型的训练过程中,生成器主要用来学习输入的训练样本(比如文本、图像数据)的真实分布,从而让自身生成的伪样本更加贴近真实的训练样本,而判别器则努力地识别输入的样本的真假。这个过程相当于一个二人博弈的过程,随着时间的推移,生成器和判别器不断地进行对抗式训练,最终两个网络达到了一个动态平衡:生成器生成的伪样本接近于真实样本,而判别器识别不出真样本和伪样本。具体实施时,在训练效果稳定以后,比如判别器的识别准确率变化波动处于-3%~3%的范围时,可执行步骤S203。
在一可选实施方式中,所述判别器可选用随机森林算法。随机森林算法是一种bagging算法,当训练样本不够时能够避免过拟合,比较适合这种场景。
步骤S203、通过所述判别器对测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第一取值。
其中,所述预测效果评估参数可以为AUC(Area Under Curve,AUC的值为ROC曲线下方的面积)、ROC(Receiver Operating Characteristic Curve接收者操作特征曲线,其横坐标表示假正类率,纵坐标表示真正类率)和/或准确率(表示在预测的正例中确实为正例的比例)等参数。
步骤S204、通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本。
在一可选实施方式中,可将获取的无标记样本和通过步骤S202训练后的生成器所生成的伪样本输入判别器,以进行机器打标(即通过判别器对输入的无标记样本和伪样本的标签进行预测)。以商品评价为例,判别器预测的标签有三类:低质量评价(可以用0表示)、高质量评价(可以用-1表示)、假数据(或称为GAN中的生成器生成的数据,其可以用-2表示)。其中,预测标签为低质量评价和高质量评价可统一看作是预测为真样本。在进行机器打标之后,可将预测标签为低质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,并将预测标签为高质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,从而得到了新的训练样本。
步骤S205、基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
在通过步骤S204得到新的训练样本之后,可基于该新的训练样本对GAN模型进行再训练,然后执行步骤S206。
步骤S206、通过再训练后的判别器对所述测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第二取值。
示例性地,假设在步骤S203中选取的预测效果评估参数为AUC,则在步骤S206中计算AUC的第二取值;假设在步骤S203中选取的预测效果评估参数为准确率,则在步骤S206中计算准确率的第二取值。
步骤S207、判断第二取值是否大于第一取值。若是,则再次执行步骤S204;若否,则执行步骤S208。
例如,在步骤S203中选取的预测效果评估参数为AUC,则将AUC的第一取值与第二取值进行比较。如果第二取值大于第一取值,表示判别器的预测效果有所提升,说明通过步骤S204追加的样本能够提高模型的训练效果,无需再次追加人工标记样本模型即可自我迭代更新,从而再次执行步骤S204即可;如果第二取值小于或等于第一取值,则说明需要再次追加人工标记样本,从而执行步骤S208。
步骤S208、向所述训练样本中追加人工标记样本。
在本发明实施例中,通过将无标记样本和GAN生成的伪样本输入判别器进行机器打标,能够将高置信度的机器打标数据添加至训练样本中,从而丰富了训练样本,减少了对人工标注数据的强依赖,实现了一种半监督学习的目的;通过计算判别器的预测效果评估参数的第一取值、第二取值,并将第一取值与第二取值进行比较,能够判断判别器的预测效果是否提升,进而判断通过步骤S204追加的样本是否能够提高模型的训练效果,从而有助于实现模型自我更新迭代的目的。
图3是GAN模型训练的原理示意图。如图3所示,GAN模型包括两部分,即生成器和判别器。GAN模型在无监督学习领域、生成领域、半监督学习领域以及强化学习领域都有广泛的应用。在将GAN模型应用在半监督学习领域时,生成器不做改变,仍然负责输入随机噪声,输出生成的伪样本。而判别器不再是一个简单的真假分类器。假设输入数据有K类,判别器就是K+1的分类器,多出的那一类是判别输入数据是否为生成器生成的数据(即伪样本)。以商品评价为例,判别器预测的标签有三类:低质量评价(可以用0表示)、高质量评价(可以用-1表示)、假数据(或称为GAN中的生成器生成的数据,其可以用-2表示)。
图4是根据本发明一个实施例的训练模型的装置的主要模块示意图。如图4所示,本发明实施例的训练模型的装置400包括:训练模块401、样本添加模块402。
训练模块401,用于基于训练样本对GAN模型中的生成器和判别器进行训练。其中,所述训练样本为人工标记样本。具体实施时,可由专家提供包含正负样本的数据。以商品评价为例,可由专家对商品评价进行人工打标,比如将商品评价人工标记为高质量评价(正样本)和低质量评价(负样本)。在得到人工标记样本之后,再将其分割成训练样本和测试样本。
GAN(生成式对抗网络)模型是一种深度学习模型,其包括生成器和判别器。在GAN模型的训练过程中,生成器主要用来学习输入的训练样本(比如文本、图像数据)的真实分布,从而让自身生成的伪样本更加贴近真实的训练样本,而判别器则努力地识别输入的样本的真假。这个过程相当于一个二人博弈的过程,随着时间的推移,生成器和判别器不断地进行对抗式训练,最终两个网络达到了一个动态平衡:生成器生成的伪样本接近于真实样本,而判别器识别不出真样本和伪样本。具体实施时,在训练效果稳定以后,比如判别器的识别准确率变化波动处于-3%~3%的范围时,可通过样本添加模块向训练样本中添加高置信度的机器打标数据。
样本添加模块402,用于通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本。
在一可选实施方式中,可将获取的无标记样本和训练后的生成器所生成的伪样本输入判别器,以进行机器打标(即通过判别器对输入的无标记样本和伪样本的标签进行预测)。其中,无标记样本可理解为获取的不带标签的真实样本。以商品评价为例,判别器预测的标签有三类:低质量评价(可以用0表示)、高质量评价(可以用-1表示)、假数据(或称为GAN中的生成器生成的数据,其可以用-2表示)。其中,预测标签为低质量评价和高质量评价可统一看作是预测为真样本。在进行机器打标之后,可将预测标签为低质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,并将预测标签为高质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,从而得到了新的训练样本。
在另一可选实施方式中,也可将获取的无标记样本或训练后的生成器所生成的伪样本输入判别器,以进行机器打标。然后,将判别器预测为真样本且标签概率值大于预设阈值的机器打标数据添加至训练样本,以得到新的训练样本。
训练模块401,还用于基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
在本发明实施例中,通过以上装置能够丰富训练样本,减少模型训练对人工标注数据(尤其是专家标注数据)的依赖,并且能够让模型根据训练效果自我优化。
图5是根据本发明另一实施例的训练模型的装置的主要模块示意图。如图5所示,本发明实施例的训练模型的装置500包括:训练模块501、样本添加模块502、评估模块503。
训练模块501,用于基于训练样本对GAN模型中的生成器和判别器进行训练。
其中,所述训练样本为人工标记样本。具体实施时,所述人工标记样本可为专家提供的包含正负样本的数据。以商品评价为例,可由专家对商品评价进行人工打标,比如将商品评价人工标记为高质量评价(正样本)和低质量评价(负样本),并且不同类别的标记样本需要满足一定数量。在得到人工标记样本之后,可按照80%、20%的比例将其切分成训练样本和测试样本。另外,也可按照70%、30%的比例将人工标记样本切分成训练样本和测试样本。
GAN(生成式对抗网络)模型是一种深度学习模型,其包括生成器和判别器。在GAN模型的训练过程中,生成器主要用来学习输入的训练样本(比如文本、图像数据)的真实分布,从而让自身生成的伪样本更加贴近真实的训练样本,而判别器则努力地识别输入的样本的真假。这个过程相当于一个二人博弈的过程,随着时间的推移,生成器和判别器不断地进行对抗式训练,最终两个网络达到了一个动态平衡:生成器生成的伪样本接近于真实样本,而判别器识别不出真样本和伪样本。具体实施时,在训练效果稳定以后,比如判别器的识别准确率变化波动处于-3%~3%的范围时,可执行步骤S203。
在一可选实施方式中,所述判别器可选用随机森林算法。随机森林算法是一种bagging算法,当训练样本不够时能够避免过拟合,比较适合这种场景。
评估模块503,用于通过所述判别器对测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第一取值。
其中,所述预测效果评估参数可以为AUC(Area Under Curve,AUC的值为ROC曲线下方的面积)、ROC(Receiver Operating Characteristic Curve接收者操作特征曲线,其横坐标表示假正类率,纵坐标表示真正类率)和/或准确率(表示在预测的正例中确实为正例的比例)等参数。
样本添加模块502,用于通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本。
在一可选实施方式中,可将获取的无标记样本和训练后的生成器所生成的伪样本输入判别器,以进行机器打标(即通过判别器对输入的无标记样本和伪样本的标签进行预测)。以商品评价为例,判别器预测的标签有三类:低质量评价(可以用0表示)、高质量评价(可以用-1表示)、假数据(或称为GAN中的生成器生成的数据,其可以用-2表示)。其中,预测标签为低质量评价和高质量评价可统一看作是预测为真样本。在进行机器打标之后,可将预测标签为低质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,并将预测标签为高质量评价、且标签概率值大于预设阈值(比如0.8)的机器打标数据添加至训练样本,从而得到了新的训练样本。
训练模块501,用于基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
评估模块503,用于通过再训练后的判别器对所述测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第二取值。
示例性地,假设在计算第一取值时选取的预测效果评估参数为AUC,则评估模块503计算AUC的第二取值;假设在计算第一取值时选取的预测效果评估参数为准确率,则评估模块503计算准确率的第二取值。
样本添加模块502,还用于在第二取值高于第一取值的情况下,再次执行所述通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本的操作。样本添加模块502,还用于在第二取值低于或等于第一取值的情况下,向所述训练样本中追加人工标记样本。
例如,选取的预测效果评估参数为AUC,则将AUC的第一取值与第二取值进行比较。如果第二取值大于第一取值,表示判别器的预测效果有所提升,说明追加的高置信度的机器打标数据能够提高模型的训练效果,无需再次追加人工标记样本模型即可自我迭代更新;如果第二取值小于或等于第一取值,则说明需要再次追加人工标记样本,从而通过样本添加模块向训练样本中追加人工标记样本。
在本发明实施例中,通过将无标记样本和GAN生成的伪样本输入判别器进行机器打标,能够将高置信度的机器打标数据添加至训练样本中,从而丰富了训练样本,减少了对人工标注数据的强依赖,实现了一种半监督学习的目的;通过评估模块计算判别器的预测效果评估参数的第一取值、第二取值,并将第一取值与第二取值进行比较,能够判断判别器的预测效果是否提升,进而判断追加的高置信度的机器打标数据是否能够提高模型的训练效果,从而有助于实现模型自我更新迭代的目的。
图6示出了可以应用本发明实施例的训练模型的方法或训练模型的装置的示例性系统架构600。
如图6所示,系统架构600可以包括终端设备601、602、603,网络604和服务器605。网络604用以在终端设备601、602、603和服务器605之间提供通信链路的介质。网络604可以包括各种连接类型,例如有线、无线通信链路或者光纤电缆等等。
用户可以使用终端设备601、602、603通过网络604与服务器605交互,以接收或发送消息等。终端设备601、602、603上可以安装有各种通讯客户端应用,例如购物类应用、网页浏览器应用、搜索类应用、即时通信工具、邮箱客户端、社交平台软件等。
终端设备601、602、603可以是具有显示屏并且支持网页浏览的各种电子设备,包括但不限于智能手机、平板电脑、膝上型便携计算机和台式计算机等等。
服务器605可以是提供各种服务的服务器,例如对用户利用终端设备601、602、603所发出的训练请求提供支持的后台管理服务器。后台管理服务器可以对接收到的训练请求等数据进行分析等处理,并将处理结果(例如训练结果)反馈给终端设备。
需要说明的是,本发明实施例所提供的训练模型的方法一般由服务器605执行,相应地,训练模型的装置一般设置于服务器605中。
应该理解,图6中的终端设备、网络和服务器的数目仅仅是示意性的。根据实现需要,可以具有任意数目的终端设备、网络和服务器。
图7示出了适于用来实现本发明实施例的电子设备的计算机系统700的结构示意图。图7示出的电子设备仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。
如图7所示,计算机系统700包括中央处理单元(CPU)701,其可以根据存储在只读存储器(ROM)702中的程序或者从存储部分708加载到随机访问存储器(RAM)703中的程序而执行各种适当的动作和处理。在RAM 703中,还存储有系统700操作所需的各种程序和数据。CPU 701、ROM 702以及RAM 703通过总线704彼此相连。输入/输出(I/O)接口705也连接至总线704。
以下部件连接至I/O接口705:包括键盘、鼠标等的输入部分706;包括诸如阴极射线管(CRT)、液晶显示器(LCD)等以及扬声器等的输出部分707;包括硬盘等的存储部分708;以及包括诸如LAN卡、调制解调器等的网络接口卡的通信部分709。通信部分709经由诸如因特网的网络执行通信处理。驱动器710也根据需要连接至I/O接口705。可拆卸介质711,诸如磁盘、光盘、磁光盘、半导体存储器等等,根据需要安装在驱动器710上,以便于从其上读出的计算机程序根据需要被安装入存储部分708。
特别地,根据本发明公开的实施例,上文参考流程图描述的过程可以被实现为计算机软件程序。例如,本发明公开的实施例包括一种计算机程序产品,其包括承载在计算机可读介质上的计算机程序,该计算机程序包含用于执行流程图所示的方法的程序代码。在这样的实施例中,该计算机程序可以通过通信部分709从网络上被下载和安装,和/或从可拆卸介质711被安装。在该计算机程序被中央处理单元(CPU)701执行时,执行本发明的系统中限定的上述功能。
需要说明的是,本发明所示的计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质或者是上述两者的任意组合。计算机可读存储介质例如可以是——但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子可以包括但不限于:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本发明中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。而在本发明中,计算机可读的信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。计算机可读的信号介质还可以是计算机可读存储介质以外的任何计算机可读介质,该计算机可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。计算机可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于:无线、电线、光缆、RF等等,或者上述的任意合适的组合。
附图中的流程图和框图,图示了按照本发明各种实施例的系统、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段、或代码的一部分,上述模块、程序段、或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个接连地表示的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图或流程图中的每个方框、以及框图或流程图中的方框的组合,可以用执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
描述于本发明实施例中所涉及到的模块可以通过软件的方式实现,也可以通过硬件的方式来实现。所描述的模块也可以设置在处理器中,例如,可以描述为:一种处理器包括训练模块、样本添加模块。其中,这些模块的名称在某种情况下并不构成对该模块本身的限定,例如,训练模块还可以被描述为“对GAN模型中的生成器和判别器进行训练的模块”。
作为另一方面,本发明还提供了一种计算机可读介质,该计算机可读介质可以是上述实施例中描述的设备中所包含的;也可以是单独存在,而未装配入该设备中。上述计算机可读介质承载有一个或者多个程序,当上述一个或者多个程序被一个该设备执行时,使得该设备执行以下流程:基于训练样本对GAN模型中的生成器和判别器进行训练;其中,所述训练样本为人工标记样本;通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本;基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
上述具体实施方式,并不构成对本发明保护范围的限制。本领域技术人员应该明白的是,取决于设计要求和其他因素,可以发生各种各样的修改、组合、子组合和替代。任何在本发明的精神和原则之内所作的修改、等同替换和改进等,均应包含在本发明保护范围之内。
Claims (10)
1.一种训练模型的方法,其特征在于,所述方法包括:
基于训练样本对GAN模型中的生成器和判别器进行训练;其中,所述训练样本为人工标记样本;
通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本;
基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
2.根据权利要求1所述的方法,其特征在于,所述方法还包括:
在执行所述基于训练样本对GAN模型中的生成器和判别器进行训练的步骤之后,通过所述判别器对测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第一取值;其中,所述测试样本为人工标记样本;
在执行所述基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练的步骤之后,通过所述判别器对所述测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第二取值;以及,在第二取值高于第一取值的情况下,再次执行所述通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本的步骤。
3.根据权利要求2所述的方法,其特征在于,所述方法还包括:
在第二取值低于或等于第一取值的情况下,向所述训练样本中追加人工标记样本。
4.根据权利要求2所述的方法,其特征在于,所述预测效果评估参数包括以下至少一项:准确率、ROC、AUC。
5.根据权利要求1所述的方法,其特征在于,所述判别器采用随机森林算法。
6.一种训练模型的装置,其特征在于,所述装置包括:
训练模块,用于基于训练样本对GAN模型中的生成器和判别器进行训练;所述训练样本为人工标记样本;
样本添加模块,用于通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本,以得到新的训练样本;
所述训练模块,还用于基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练。
7.根据权利要求6所述的装置,其特征在于,所述装置还包括:
评估模块,用于在所述训练模块基于训练样本对GAN模型中的生成器和判别器进行训练之后,通过所述判别器对测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第一取值;其中,所述测试样本为人工标记样本;
所述评估模块,还用于在所述训练模块基于所述新的训练样本对GAN模型中的生成器和判别器进行再训练之后,通过所述判别器对所述测试样本进行机器打标,以得到所述判别器的预测效果评估参数的第二取值;
所述样本添加模块,还用于在第二取值高于第一取值的情况下,再次执行所述通过所述判别器对无标记样本和/或由所述生成器生成的伪样本进行机器打标,然后将预测为真样本且标签概率值大于预设阈值的机器打标数据添加至所述训练样本的步骤。
8.根据权利要求7所述的装置,其特征在于,所述样本添加模块,还用于在第二取值低于或等于第一取值的情况下,向所述训练样本中追加人工标记样本。
9.一种电子设备,其特征在于,包括:
一个或多个处理器;
存储装置,用于存储一个或多个程序,
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如权利要求1至5中任一所述的方法。
10.一种计算机可读介质,其上存储有计算机程序,其特征在于,所述程序被处理器执行时实现如权利要求1至5中任一所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201810695632.2A CN110659657B (zh) | 2018-06-29 | 2018-06-29 | 训练模型的方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201810695632.2A CN110659657B (zh) | 2018-06-29 | 2018-06-29 | 训练模型的方法和装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110659657A true CN110659657A (zh) | 2020-01-07 |
CN110659657B CN110659657B (zh) | 2024-05-24 |
Family
ID=69027537
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201810695632.2A Active CN110659657B (zh) | 2018-06-29 | 2018-06-29 | 训练模型的方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110659657B (zh) |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111582647A (zh) * | 2020-04-09 | 2020-08-25 | 上海淇毓信息科技有限公司 | 用户数据处理方法、装置及电子设备 |
CN111783016A (zh) * | 2020-07-03 | 2020-10-16 | 支付宝(杭州)信息技术有限公司 | 一种网站分类方法、装置及设备 |
CN112183088A (zh) * | 2020-09-28 | 2021-01-05 | 云知声智能科技股份有限公司 | 词语层级确定的方法、模型构建方法、装置及设备 |
CN112420205A (zh) * | 2020-12-08 | 2021-02-26 | 医惠科技有限公司 | 实体识别模型生成方法、装置及计算机可读存储介质 |
CN112581472A (zh) * | 2021-01-26 | 2021-03-30 | 中国人民解放军国防科技大学 | 一种面向人机交互的目标表面缺陷检测方法 |
CN112988854A (zh) * | 2021-05-20 | 2021-06-18 | 创新奇智(成都)科技有限公司 | 一种申诉数据挖掘方法、装置、电子设备及存储介质 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20080103996A1 (en) * | 2006-10-31 | 2008-05-01 | George Forman | Retraining a machine-learning classifier using re-labeled training samples |
CN104751182A (zh) * | 2015-04-02 | 2015-07-01 | 中国人民解放军空军工程大学 | 基于ddag的svm多类分类主动学习算法 |
CN104966097A (zh) * | 2015-06-12 | 2015-10-07 | 成都数联铭品科技有限公司 | 一种基于深度学习的复杂文字识别方法 |
CN107622056A (zh) * | 2016-07-13 | 2018-01-23 | 百度在线网络技术(北京)有限公司 | 训练样本的生成方法和装置 |
KR101843066B1 (ko) * | 2017-08-23 | 2018-05-15 | 주식회사 뷰노 | 기계 학습에 있어서 데이터 확대를 이용하여 데이터의 분류를 수행하는 방법 및 이를 이용한 장치 |
CN108121975A (zh) * | 2018-01-04 | 2018-06-05 | 中科汇通投资控股有限公司 | 一种联合原始数据和生成数据的人脸识别方法 |
CN108171770A (zh) * | 2018-01-18 | 2018-06-15 | 中科视拓(北京)科技有限公司 | 一种基于生成式对抗网络的人脸表情编辑方法 |
-
2018
- 2018-06-29 CN CN201810695632.2A patent/CN110659657B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20080103996A1 (en) * | 2006-10-31 | 2008-05-01 | George Forman | Retraining a machine-learning classifier using re-labeled training samples |
CN104751182A (zh) * | 2015-04-02 | 2015-07-01 | 中国人民解放军空军工程大学 | 基于ddag的svm多类分类主动学习算法 |
CN104966097A (zh) * | 2015-06-12 | 2015-10-07 | 成都数联铭品科技有限公司 | 一种基于深度学习的复杂文字识别方法 |
CN107622056A (zh) * | 2016-07-13 | 2018-01-23 | 百度在线网络技术(北京)有限公司 | 训练样本的生成方法和装置 |
KR101843066B1 (ko) * | 2017-08-23 | 2018-05-15 | 주식회사 뷰노 | 기계 학습에 있어서 데이터 확대를 이용하여 데이터의 분류를 수행하는 방법 및 이를 이용한 장치 |
CN108121975A (zh) * | 2018-01-04 | 2018-06-05 | 中科汇通投资控股有限公司 | 一种联合原始数据和生成数据的人脸识别方法 |
CN108171770A (zh) * | 2018-01-18 | 2018-06-15 | 中科视拓(北京)科技有限公司 | 一种基于生成式对抗网络的人脸表情编辑方法 |
Non-Patent Citations (3)
Title |
---|
刘海东 等: "基于生成对抗网络的乳腺癌病理图像可疑区域标记", 科研信息化技术与应用, vol. 8, no. 6, pages 52 - 64 * |
杜秋平 等: "基于图像云模型语义标注的条件生成对抗网络", 模式识别与人工智能, vol. 31, no. 04, pages 379 - 388 * |
蒋芸 等: "基于条件生成对抗网络的咬翼片图像分割", 计算机工程, vol. 45, no. 04, pages 223 - 227 * |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111582647A (zh) * | 2020-04-09 | 2020-08-25 | 上海淇毓信息科技有限公司 | 用户数据处理方法、装置及电子设备 |
CN111783016A (zh) * | 2020-07-03 | 2020-10-16 | 支付宝(杭州)信息技术有限公司 | 一种网站分类方法、装置及设备 |
CN111783016B (zh) * | 2020-07-03 | 2021-05-04 | 支付宝(杭州)信息技术有限公司 | 一种网站分类方法、装置及设备 |
CN112183088A (zh) * | 2020-09-28 | 2021-01-05 | 云知声智能科技股份有限公司 | 词语层级确定的方法、模型构建方法、装置及设备 |
CN112183088B (zh) * | 2020-09-28 | 2023-11-21 | 云知声智能科技股份有限公司 | 词语层级确定的方法、模型构建方法、装置及设备 |
CN112420205A (zh) * | 2020-12-08 | 2021-02-26 | 医惠科技有限公司 | 实体识别模型生成方法、装置及计算机可读存储介质 |
CN112581472A (zh) * | 2021-01-26 | 2021-03-30 | 中国人民解放军国防科技大学 | 一种面向人机交互的目标表面缺陷检测方法 |
CN112988854A (zh) * | 2021-05-20 | 2021-06-18 | 创新奇智(成都)科技有限公司 | 一种申诉数据挖掘方法、装置、电子设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN110659657B (zh) | 2024-05-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109460513B (zh) | 用于生成点击率预测模型的方法和装置 | |
CN110659657B (zh) | 训练模型的方法和装置 | |
CN111428010B (zh) | 人机智能问答的方法和装置 | |
WO2018153806A1 (en) | Training machine learning models | |
CN109145828B (zh) | 用于生成视频类别检测模型的方法和装置 | |
CN110929799B (zh) | 用于检测异常用户的方法、电子设备和计算机可读介质 | |
CN112650841A (zh) | 信息处理方法、装置和电子设备 | |
US20220036178A1 (en) | Dynamic gradient aggregation for training neural networks | |
KR102614912B1 (ko) | 딥러닝 기반 특허 잠재가치 평가 장치 및 그 방법 | |
CN112966701A (zh) | 目标分类的方法和装置 | |
CN113743971A (zh) | 一种数据处理方法和装置 | |
CN113627536A (zh) | 模型训练、视频分类方法,装置,设备以及存储介质 | |
CN110909768B (zh) | 一种标注数据获取方法及装置 | |
CN113392920B (zh) | 生成作弊预测模型的方法、装置、设备、介质及程序产品 | |
CN113051911B (zh) | 提取敏感词的方法、装置、设备、介质及程序产品 | |
CN117235371A (zh) | 视频推荐方法、模型训练方法及装置 | |
CN112633004A (zh) | 文本标点符号删除方法、装置、电子设备和存储介质 | |
CN113111167A (zh) | 基于深度学习模型的接处警文本车辆型号提取方法和装置 | |
CN112000872A (zh) | 基于用户向量的推荐方法、模型的训练方法及装置 | |
CN111858916A (zh) | 用于聚类句子的方法和装置 | |
CN111784377B (zh) | 用于生成信息的方法和装置 | |
CN110633476B (zh) | 用于获取知识标注信息的方法及装置 | |
CN113886543A (zh) | 生成意图识别模型的方法、装置、介质及程序产品 | |
CN113111165A (zh) | 基于深度学习模型的接警警情类别确定方法和装置 | |
CN113111234A (zh) | 基于正则表达式的处警警情类别确定方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |