CN114528896A - 模型训练、数据增强方法、装置、电子设备及存储介质 - Google Patents
模型训练、数据增强方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN114528896A CN114528896A CN202011320953.8A CN202011320953A CN114528896A CN 114528896 A CN114528896 A CN 114528896A CN 202011320953 A CN202011320953 A CN 202011320953A CN 114528896 A CN114528896 A CN 114528896A
- Authority
- CN
- China
- Prior art keywords
- data
- sample data
- network model
- confrontation network
- discriminator
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 79
- 238000012549 training Methods 0.000 title claims abstract description 68
- 239000000523 sample Substances 0.000 claims abstract description 117
- 239000013074 reference sample Substances 0.000 claims abstract description 44
- 230000006870 function Effects 0.000 claims description 40
- 238000004891 communication Methods 0.000 claims description 19
- 238000010606 normalization Methods 0.000 claims description 8
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000004590 computer program Methods 0.000 claims description 3
- 238000010586 diagram Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 238000005457 optimization Methods 0.000 description 4
- 230000008569 process Effects 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000004880 explosion Methods 0.000 description 1
- 230000001771 impaired effect Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000003278 mimic effect Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000002093 peripheral effect Effects 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000000926 separation method Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Probability & Statistics with Applications (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及一种模型训练、数据增强方法、装置、电子设备及存储介质,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述方法包括:所述生成器生成参考样本数据;第一判别器计算参考样本数据与预设负样本数据之间的第一距离;第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;基于所述第一距离和所述第二距离确定目标函数;利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。本发明实施例能够使得数据样本标签不平衡。
Description
技术领域
本申请涉及计算机技术领域,尤其涉及一种模型训练、数据增强方法、装置、电子设备及存储介质。
背景技术
随着数据采集技术的不断进步,越来愈多的数据正在被收集并广泛应用于商业分析,金融服务和医疗教育等各个方面。
但是,由于数据本身的不平衡性和采集手段的限制,相当多的数据存在没有标签或者标签不平衡的情况,导致模型结果不理想,甚至输出错误的结果,这对我们当前的数据预处理技术带来了极大的挑战。具体来讲,数据样本标签不平衡是指在不同标签的数据源中,某一些标签的数据占绝大部分,而另外一些标签的数据只占很少一部分。例如在二分类预测问题中,标签为“1”的数据占总量的99%,而标签为“0”的数据只占1%。这种数据常常会损害模型的效果,使一个二分类模型无法获得较好的预测结果。
发明内容
为了解决上述技术问题或者至少部分地解决上述技术问题,本申请提供了一种模型训练、数据增强方法、装置、电子设备及存储介质。
第一方面,本申请提供了一种模型训练方法,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述方法包括:
所述生成器生成参考样本数据;
第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
基于所述第一距离和所述第二距离确定目标函数;
利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
可选的,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
可选的,所述利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型,包括:
利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
可选的,所述目标函数为:
其中,posData表示正类数据,negData表示负类数据,allData表示生成的负类数据和原有负类数据的并集。D1表示第一判别器参数,D2表示第二判别器参数,G表示生成器参数。
可选的,所述第一判别器和所述第二判别器的结构相同,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-ReLU层和sigmoid层。
可选的,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-ReLU层。
第二方面,本申请提供了一种数据增强方法,包括:
利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如第一方面任一所述的模型训练方法训练得到的;
将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
第三方面,本申请提供了一种模型训练装置,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述装置包括:
生成模块,用于所述生成器生成参考样本数据;
第一计算模块,用于第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
第二计算模块,用于第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
选择模块,用于基于所述第一距离和所述第二距离确定目标函数;
训练模块,用于利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
可选地,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
可选地,所述训练模块,还用于:
利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
可选地,所述目标函数为:
其中,posData表示正类数据,negData表示负类数据,allData表示生成的负类数据和原有负类数据的并集。D1表示第一判别器参数,D2表示第二判别器参数,G表示生成器参数。
可选地,所述第一判别器和所述第二判别器的结构相同,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-ReLU层和sigmoid层。
可选地,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-ReLU层。
第四方面,本申请提供了一种数据增强装置,包括:
生成模块,用于利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如权利要求8所述的模型训练方法训练得到的;
添加模块,用于将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
第五方面,本申请提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现第一方面任一所述的模型训练方法或第二方面所述的数据增强方法。
第六方面,本申请提供了一种计算机可读存储介质,所述计算机可读存储介质上存储有模型训练方法的程序或者数据增强方法的程序,所述模型训练方法的程序被处理器执行时实现第一方面任一所述的模型训练方法的步骤,所述数据增强方法的程序被处理器执行时实现第二方面所述的数据增强方法的步骤。
本申请实施例提供的上述技术方案与现有技术相比具有如下优点:
本申请实施例提供的该方法,本发明实施例通过生成器生成参考样本数据,第一判别器计算参考样本数据与预设负样本数据之间的第一距离,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离,再基于所述第一距离和所述第二距离确定目标函数,最后可以利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
本发明实施例通过生成器生成参考样本数据,基于第一距离和第二距离确定目标函数,利用所述目标函数训练所述生成对抗网络模型,可以使训练完成的生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本发明的实施例,并与说明书一起用于解释本发明的原理。
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的一种生成对抗网络模型的原理示意图;
图2为本申请实施例提供的一种模型训练方法的一种流程图;
图3为图1中步骤S105的流程图;
图4为本申请实施例提供的一种模型训练方法的另一种流程图;
图5为本申请实施例提供的一种模型训练装置的结构图;
图6为本申请实施例提供的另一种模型训练装置的结构图;
图7为本申请实施例提供的一种电子设备的结构图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。
在实现本发明的过程中,发明人发现,现有的技术方案往往通过上采样,下采样和对样本分配权重的方式以解决数据样本标签不平衡的问题。这些方法往往存在一些缺陷。其一,这些方法有时难以取得较好的效果。以下采样为例,这种方法通过对标签较多的那一类数据进行下采样,使得两种(或多种)标签的数据具有相似的数量。但是,对于不平衡较为严重的情况来说,这种方法会大大减少可以使用的数据量,损害了模型的效果。其二,有些方法对于模型的依赖较为严重,效果可能随着模型的变换而变化。比如对样本分配权重的方法就要求模型必须可以处理带权重的样本。另外对于样本权重的选取也增加了应用该方法的难度。为此,本发明实施例提供了一种模型训练、数据增强方法、装置、电子设备及存储介质,所述模型训练方法用于训练生成对抗网络模型,生成对抗网络:是机器学习中非监督式学习的一种方法,通过让两个神经网络相互博弈的方式进行学习。生成对抗网络由一个生成网络与一个判别网络组成。生成网络从潜在空间(latent space)中随机取样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。而生成网络则要尽可能地欺骗判别网络。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。
不同于一般的生成对抗网络模型,在本发明实施例中,同时利用正样本和负样本生成负样本数据对生成对抗网络模型进行训练。本发明实施例所依据的原理是:减小生成的数据与负样本之间的差异,并增大生成数据与正样本的差异。通过这种方法生成的负样本能够保持与真实负样本分布相近但是与正样本保持足够的分离间隔。使得重组后的数据能够使分类器更好地找到正负类的分离面。
在本发明实施例中,如图1所示,生成对抗网络模型包括:生成器(generator)和两个判别器(discriminator),也就是说,模型训练方法用于训练生成器和两个判别器。其中,所述生成器的输出作为两个所述判别器的输入,假设两个判别器分别为第一判别器和第二判别器,生成器用于将输入的随机噪声数据转化为和真实负样本分布相近的数据,从而生成参考样本数据(负样本数据),达到数据增强的目的;
将参考样本数据和预设负样本数据输入第一判别器中,第一判别器判别参考样本数据和预设负样本数据之间的差距,也即第一判别器用于判断参考样本数据和预设负样本数据是否属于同一类;
将参考样本数据和预设负样本数据合并,得到负类数据,将负类数据和预设正样本数据输入第二判别器,第二判别器判别负类数据和预设正样本数据之间的差距,也就是说,第二判别器用于判断负类数据和预设正样本数据是否为同一类。
如图2所示,所述模型训练方法可以包括以下步骤:
步骤S101,所述生成器生成参考样本数据;
在本发明实施例中,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-ReLU层,其中,标准化层可以指batch-normalization算法层,batch-normalization算法层用于防止梯度爆炸,在本发明实施例中,第一级生成单元中全连接层和leaky-ReLU层的维度均为256,第二级生成单元中全连接层和leaky-ReLU层的维度均为512,第三级生成单元中全连接层和leaky-ReLU层的维度均为1024。
在步骤S101之前,可以获取原始数据集及服从高斯分布的随机噪声数据,原始数据集中包括预设正样本数据和负样本数据。
为表述方便,在本发明实施例中,将标签较少的样本称作负样本数据,将标签较多的样本称作正样本数据,并且令负样本的标签为-1,正样本的标签为1。
在该步骤中,可以将服从高斯分布的随机噪声数据输入至生成器的输入层,随机噪声数据的维度是100维,生成器可以基于随机噪声数据生成参考样本数据。
步骤S102,第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
在本发明实施例中,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层和leaky-ReLU层,第一级判别单元中全连接层和leaky-ReLU层的维度均为512,第二级判别单元中全连接层和leaky-ReLU层的维度均为256。
步骤S103,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
在本发明实施例中,所述第二判别器和所述第一判别器的结构相同,所述第二判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-ReLU层和sigmoid层。
步骤S104,基于所述第一距离和所述第二距离确定目标函数;
为了减小参考样本数据与负样本之间的差异,并增大参考样本数据与正样本的差异,也就是说,本发明实施例的目的是使目标样本数据可以使第一分类器产生较大的误差(即:使目标样本数据和预设负样本数据差距较小),而使第二分类器产生较小的误差(即:使目标样本数据和预设正样本数据差距较大)。
也就是说,在本发明实施例中,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
所以在该步骤中,可以基于第一距离和第二距离在参考样本数据中选择满足预设样本平衡条件的目标样本数据,预设样本平衡条件可以指与预设负样本预设负样本数据差距较小,且,和预设正样本数据差距较大。
满足预设样本平衡条件的目标样本数据即参考样本数据中,第一距离较小且第二距离较大的目标样本数据,示例性的,目标样本数据可以指参考样本数据中,第一距离小于预设第一阈值且第二距离大于预设第二阈值的目标样本数据。
步骤S105,利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
在该步骤中,可以将所述预设负样本数据和所述正样本数据输入生成对抗网络模型,基于生成对抗网络模型输出的输出数据与所述目标样本数据之间的差异,不断的调整生成对抗网络模型的模型参数,直至输出数据与所述目标样本数据一致,确定生成对抗网络模型收敛,得到所述生成对抗网络模型,以用于数据增强。
本发明实施例通过生成器生成参考样本数据,第一判别器计算参考样本数据与预设负样本数据之间的第一距离,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离,再基于所述第一距离和所述第二距离确定目标函数,最后可以利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
本发明实施例通过生成器生成参考样本数据,基于第一距离和第二距离确定目标函数,利用所述目标函数训练所述生成对抗网络模型,可以使训练完成的生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
在本发明的又一实施例中,如图3所示,所述步骤S105可以包括以下步骤:
步骤S301,利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
在本发明实施例中,所述目标函数为:
其中,posData表示正类数据,negData表示负类数据,allData表示生成的负类数据和原有负类数据的并集。D1表示第一判别器参数,D2表示第二判别器参数,G表示生成器参数。
步骤S302,将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
本发明实施例通过目标函数,能够不断的调整模型参数,最终得到生成器参数、第一判别器参数和第二判别器参数,便于使生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
在本发明的又一实施例中,还提供一种数据增强方法,如图4所示,所述方法包括:
步骤S401,利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如前述方法实施例所述的模型训练方法训练得到的;
在该步骤中,生成对抗网络模型的输入数据为服从高斯分布的随机噪声数据,再利用生成对抗网络模型进行数据增强时,生成对抗网络模型的输入数据与训练该生成对抗网络模型时输入至生成器的服从高斯分布的随机噪声数据相同。
第二负样本数据加上预设负样本数据的总数一般应该与预设正样本数据的数量相同。
生成第二负样本数据后,将第二负样本数据对应的数据标签设置为-1(即与预设负样本数据的标签相同)。
步骤S402,将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
在该步骤中,可以将生成的第二负样本数据加入原数据集,并将整个数据集进行随机打乱操作,得到新数据集。
本发明实施例能够生成第二负样本数据,并将生成第二负样本数据加入原始数据集,得到可直接用于训练的新数据集,新数据集对其所运用的模型没有依赖。
在本发明的又一实施例中,还提供一种模型训练装置,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,如图5所示,所述装置包括:
生成模块11,用于所述生成器生成参考样本数据;
第一计算模块12,用于第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
第二计算模块13,用于第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
选择模块14,用于基于所述第一距离和所述第二距离确定目标函数;
训练模块15,用于利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
可选地,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
可选地,所述训练模块,还用于:
利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
可选地,所述目标函数为:
其中,posData表示正类数据,negData表示负类数据,allData表示生成的负类数据和原有负类数据的并集。D1表示第一判别器参数,D2表示第二判别器参数,G表示生成器参数。
可选地,所述第一判别器和所述第二判别器的结构相同,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-ReLU层和sigmoid层。
可选地,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-ReLU层。
在本发明的又一实施例中,还提供一种数据增强装置,如图6所示,包括:
生成模块21,用于利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如前述装置实施例所述的模型训练方法训练得到的;
添加模块22,用于将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
在本发明的又一实施例中,还提供一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现前述方法实施例所述的模型训练方法或前述方法实施例所述的数据增强方法。
本发明实施例提供的电子设备,处理器通过执行存储器上所存放的程序实现了本发明实施例通过生成器生成参考样本数据,第一判别器计算参考样本数据与预设负样本数据之间的第一距离,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离,再基于所述第一距离和所述第二距离确定目标函数,最后可以利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。本发明实施例通过生成器生成参考样本数据,基于第一距离和第二距离在参考样本数据中选择满足预设样本平衡条件的目标样本数据,最后利用目标样本数据、预设负样本数据和正样本数据训练生成对抗网络模型,可以使训练完成的生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
上述电子设备提到的通信总线1140可以是外设部件互连标准(PeripheralComponentInterconnect,简称PCI)总线或扩展工业标准结构(ExtendedIndustryStandardArchitecture,简称EISA)总线等。该通信总线1140可以分为地址总线、数据总线、控制总线等。为便于表示,图7中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口1120用于上述电子设备与其他设备之间的通信。
存储器1130可以包括随机存取存储器(RandomAccessMemory,简称RAM),也可以包括非易失性存储器(non-volatilememory),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器1110可以是通用处理器,包括中央处理器(CentralProcessingUnit,简称CPU)、网络处理器(NetworkProcessor,简称NP)等;还可以是数字信号处理器(DigitalSignalProcessing,简称DSP)、专用集成电路(ApplicationSpecificIntegratedCircuit,简称ASIC)、现场可编程门阵列(Field-ProgrammableGateArray,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本发明的又一实施例中,还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有模型训练方法的程序或者数据增强方法的程序,所述模型训练方法的程序被处理器执行时实现前述方法实施例所述的模型训练方法的步骤,所述数据增强方法的程序被处理器执行时实现前述方法实施例所述的数据增强方法的步骤。
需要说明的是,在本文中,诸如“第一”和“第二”等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上所述仅是本发明的具体实施方式,使本领域技术人员能够理解或实现本发明。对这些实施例的多种修改对本领域的技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所申请的原理和新颖特点相一致的最宽的范围。
Claims (11)
1.一种模型训练方法,其特征在于,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述方法包括:
所述生成器生成参考样本数据;
第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
基于所述第一距离和所述第二距离确定目标函数;
利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
2.根据权利要求1所述的模型训练方法,其特征在于,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
3.根据权利要求1所述的模型训练方法,其特征在于,所述利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型,包括:
利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
5.根据权利要求1所述的模型训练方法,其特征在于,所述第一判别器和所述第二判别器的结构相同,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-ReLU层和sigmoid层。
6.根据权利要求1所述的模型训练方法,其特征在于,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-ReLU层。
7.一种数据增强方法,其特征在于,包括:
利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如权利要求1至6任一所述的模型训练方法训练得到的;
将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
8.一种模型训练装置,其特征在于,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述装置包括:
生成模块,用于所述生成器生成参考样本数据;
第一计算模块,用于第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
第二计算模块,用于第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
选择模块,用于基于所述第一距离和所述第二距离确定目标函数;
训练模块,用于利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
9.一种数据增强装置,其特征在于,包括:
生成模块,用于利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如权利要求8所述的模型训练方法训练得到的;
添加模块,用于将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
10.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1~6任一所述的模型训练方法或权利要求7所述的数据增强方法。
11.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有模型训练方法的程序或者数据增强方法的程序,所述模型训练方法的程序被处理器执行时实现权利要求1-6任一所述的模型训练方法的步骤,所述数据增强方法的程序被处理器执行时实现权利要求7所述的数据增强方法的步骤。
Priority Applications (5)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011320953.8A CN114528896A (zh) | 2020-11-23 | 2020-11-23 | 模型训练、数据增强方法、装置、电子设备及存储介质 |
KR1020237015037A KR20230107558A (ko) | 2020-11-23 | 2021-11-15 | 모델 트레이닝, 데이터 증강 방법, 장치, 전자 기기 및 저장 매체 |
JP2023531631A JP2023550194A (ja) | 2020-11-23 | 2021-11-15 | モデル訓練方法、データ強化方法、装置、電子機器及び記憶媒体 |
US18/254,158 US20240037408A1 (en) | 2020-11-23 | 2021-11-15 | Method and apparatus for model training and data enhancement, electronic device and storage medium |
PCT/CN2021/130667 WO2022105713A1 (zh) | 2020-11-23 | 2021-11-15 | 模型训练、数据增强方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011320953.8A CN114528896A (zh) | 2020-11-23 | 2020-11-23 | 模型训练、数据增强方法、装置、电子设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114528896A true CN114528896A (zh) | 2022-05-24 |
Family
ID=81618498
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011320953.8A Pending CN114528896A (zh) | 2020-11-23 | 2020-11-23 | 模型训练、数据增强方法、装置、电子设备及存储介质 |
Country Status (5)
Country | Link |
---|---|
US (1) | US20240037408A1 (zh) |
JP (1) | JP2023550194A (zh) |
KR (1) | KR20230107558A (zh) |
CN (1) | CN114528896A (zh) |
WO (1) | WO2022105713A1 (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114943585A (zh) * | 2022-05-27 | 2022-08-26 | 天翼爱音乐文化科技有限公司 | 一种基于生成对抗网络的业务推荐方法及系统 |
CN115328062A (zh) * | 2022-08-31 | 2022-11-11 | 济南永信新材料科技有限公司 | 水刺布生产线智能控制系统 |
CN117093715A (zh) * | 2023-10-18 | 2023-11-21 | 湖南财信数字科技有限公司 | 词库扩充方法、系统、计算机设备及存储介质 |
CN117454181A (zh) * | 2023-11-16 | 2024-01-26 | 国网山东省电力公司枣庄供电公司 | 基于级联生成对抗网络的局部放电数据生成方法 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109190446A (zh) * | 2018-07-06 | 2019-01-11 | 西北工业大学 | 基于三元组聚焦损失函数的行人再识别方法 |
CN111476827B (zh) * | 2019-01-24 | 2024-02-02 | 曜科智能科技(上海)有限公司 | 目标跟踪方法、系统、电子装置及存储介质 |
CN110765866B (zh) * | 2019-09-18 | 2021-02-05 | 新疆爱华盈通信息技术有限公司 | 人脸识别方法和人脸识别设备 |
CN111522985B (zh) * | 2020-04-21 | 2023-04-07 | 易拍全球(北京)科贸有限公司 | 基于深浅层特征提取与融合的古董艺术品图像检索方法 |
CN111930992B (zh) * | 2020-08-14 | 2022-10-28 | 腾讯科技(深圳)有限公司 | 神经网络训练方法、装置及电子设备 |
-
2020
- 2020-11-23 CN CN202011320953.8A patent/CN114528896A/zh active Pending
-
2021
- 2021-11-15 WO PCT/CN2021/130667 patent/WO2022105713A1/zh active Application Filing
- 2021-11-15 JP JP2023531631A patent/JP2023550194A/ja active Pending
- 2021-11-15 KR KR1020237015037A patent/KR20230107558A/ko active Search and Examination
- 2021-11-15 US US18/254,158 patent/US20240037408A1/en active Pending
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114943585A (zh) * | 2022-05-27 | 2022-08-26 | 天翼爱音乐文化科技有限公司 | 一种基于生成对抗网络的业务推荐方法及系统 |
CN114943585B (zh) * | 2022-05-27 | 2023-05-05 | 天翼爱音乐文化科技有限公司 | 一种基于生成对抗网络的业务推荐方法及系统 |
CN115328062A (zh) * | 2022-08-31 | 2022-11-11 | 济南永信新材料科技有限公司 | 水刺布生产线智能控制系统 |
US11853019B1 (en) | 2022-08-31 | 2023-12-26 | Jinan Winson New Materials Technology Co., Ltd. | Intelligent control of spunlace production line using classification of current production state of real-time production line data |
CN117093715A (zh) * | 2023-10-18 | 2023-11-21 | 湖南财信数字科技有限公司 | 词库扩充方法、系统、计算机设备及存储介质 |
CN117093715B (zh) * | 2023-10-18 | 2023-12-29 | 湖南财信数字科技有限公司 | 词库扩充方法、系统、计算机设备及存储介质 |
CN117454181A (zh) * | 2023-11-16 | 2024-01-26 | 国网山东省电力公司枣庄供电公司 | 基于级联生成对抗网络的局部放电数据生成方法 |
Also Published As
Publication number | Publication date |
---|---|
US20240037408A1 (en) | 2024-02-01 |
WO2022105713A1 (zh) | 2022-05-27 |
JP2023550194A (ja) | 2023-11-30 |
KR20230107558A (ko) | 2023-07-17 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114528896A (zh) | 模型训练、数据增强方法、装置、电子设备及存储介质 | |
CN110009093B (zh) | 用于分析关系网络图的神经网络系统和方法 | |
CN110852447B (zh) | 元学习方法和装置、初始化方法、计算设备和存储介质 | |
CN111352965B (zh) | 序列挖掘模型的训练方法、序列数据的处理方法及设备 | |
CN111241287A (zh) | 用于生成对抗文本的生成模型的训练方法及装置 | |
CN110009486B (zh) | 一种欺诈检测的方法、系统、设备及计算机可读存储介质 | |
CN111626349A (zh) | 一种基于深度学习的目标检测方法和系统 | |
CN112364942B (zh) | 信贷数据样本均衡方法、装置、计算机设备及存储介质 | |
CN111178435B (zh) | 一种分类模型训练方法、系统、电子设备及存储介质 | |
CN113987196A (zh) | 一种基于知识图谱蒸馏的知识图谱嵌入压缩方法 | |
JP5453107B2 (ja) | 音声セグメンテーションの方法および装置 | |
CN111062806A (zh) | 个人金融信用风险评价方法、系统和存储介质 | |
CN111241258A (zh) | 数据清洗方法、装置、计算机设备及可读存储介质 | |
Leng et al. | Single-shot augmentation detector for object detection | |
CN112508684A (zh) | 一种基于联合卷积神经网络的催收风险评级方法及系统 | |
CN111412795A (zh) | 测试点设置方案生成方法及装置 | |
CN116468095A (zh) | 神经网络架构搜索方法及装置、设备、芯片、存储介质 | |
CN111159397B (zh) | 文本分类方法和装置、服务器 | |
CN112529303A (zh) | 基于模糊决策的风险预测方法、装置、设备和存储介质 | |
Ein-Dor et al. | Confidence in prediction by neural networks | |
KR102136984B1 (ko) | 기술 거래 서비스 제공방법 및 기술 거래 서버 | |
CN114205459A (zh) | 基于网络切片的异常话单检测方法及装置 | |
Kane | An Automated Approach to Incident Routing and Response | |
CN113344419A (zh) | 适用于中小型企业的科技创新统筹管理方法及系统 | |
CN112347371A (zh) | 基于社交文本信息的资源归还增比方法、装置和电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |