CN111582348A - 条件生成式对抗网络的训练方法、装置、设备及存储介质 - Google Patents
条件生成式对抗网络的训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN111582348A CN111582348A CN202010359482.5A CN202010359482A CN111582348A CN 111582348 A CN111582348 A CN 111582348A CN 202010359482 A CN202010359482 A CN 202010359482A CN 111582348 A CN111582348 A CN 111582348A
- Authority
- CN
- China
- Prior art keywords
- target
- training
- condition
- discriminator
- generator
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 156
- 238000000034 method Methods 0.000 title claims abstract description 71
- 239000013598 vector Substances 0.000 claims abstract description 175
- 238000007781 pre-processing Methods 0.000 claims abstract description 18
- 230000006870 function Effects 0.000 claims description 36
- 230000004913 activation Effects 0.000 claims description 18
- 230000003042 antagnostic effect Effects 0.000 claims description 15
- 230000006872 improvement Effects 0.000 claims description 6
- 230000008569 process Effects 0.000 abstract description 31
- 238000009826 distribution Methods 0.000 description 11
- 238000010606 normalization Methods 0.000 description 10
- 238000004891 communication Methods 0.000 description 5
- 238000005457 optimization Methods 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 4
- 238000010586 diagram Methods 0.000 description 4
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 3
- 230000008034 disappearance Effects 0.000 description 3
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000008859 change Effects 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000004880 explosion Methods 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000009827 uniform distribution Methods 0.000 description 2
- 210000004556 brain Anatomy 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 230000001902 propagating effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Abstract
本发明属于生成式对抗网络技术领域,公开了一种条件生成式对抗网络的训练方法、装置、设备及存储介质。该方法通过获取真实样本图片,对真实样本图片进行图像预处理,以获得目标样本图片;对目标样本图片进行分类以获得分类结果,并根据分类结果设置条件向量;获取条件生成式对抗网络;基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及条件向量对条件生成式对抗网络中的生成器和判别器进行设置,以获得目标判别器和目标生成器;基于真实样本图片及条件向量对目标生成器和目标判别器进行训练。在条件生成式对抗网络中引入Wasserstein GAN运行机制,同时完成了稳定训练和进程指标的问题,从而解决了现有技术在条件生成式对抗网络训练时的稳定性不高和效率低的技术问题。
Description
技术领域
本发明涉及生成式对抗网络技术领域,尤其涉及一种条件生成式对抗网络的训练方法、装置、设备及存储介质。
背景技术
随着计算机硬件和神经网络领域的发展,人工智能逐渐得到了人们的重视,也在人们生活中发挥着越来越重要的作用。深度学习源于神经网络的发展,其概念由Hinton等人于2006年提出,其目的是为了模拟人脑进行分析和解释数据。人们希望通过深度学习找到一个深层次的神经网络模型,这个模型可以表示在人工智能应用中遇到的各种数据之中的概率分布,这些应用包括图像处理、自然语言处理等。到目前为止,深度学习中最令人瞩目的成就之一就是判别器,它可以接收一个高纬度输入并将其转化为一个分类标签。深度学习可以分为有监督学习、半监督学习和无监督学习等几类。生成对抗网络就是一种典型的、非常有发展前景的无监督学习,其本质是一个“对抗”的过程,它是由Ian Goodfellow等人于2014年10月提出的一种通过对抗过程估计生成器的神经网络模型。但是原始的生成对抗网络训练不稳定,生成器面临着梯度消失的问题,还经常出现模型崩溃问题(modecollapse)。
条件生成对抗网络(Conditional GAN)是紧接着原始生成对抗网络被提出来的,条件生成对抗网络是给原始生成对抗网络提供一些“暗示(hint)”来提醒原始生成对抗网络应该生成什么样的输出,原始生成对抗网络的生成过程变成基于某些额外信息的生成。这个额外的“暗示”是直接拼接在原始生成对抗网络的输入上实现的,操作十分简单。
人类可以轻松地发现不同域数据之间的关系,但是对于机器来说,想要学习这个关系是非常有挑战性的,有时可能还需要专门制造一些成对的不同域数据传给机器来学习。自从条件生成式对抗网络提出以来,就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性,训练时的稳定性不高和效率低的问题。
上述内容仅用于辅助理解本发明的技术方案,并不代表承认上述内容是现有技术。
发明内容
本发明的主要目的在于提供一种条件生成式对抗网络的训练方法、装置、设备及存储介质,旨在解决现有技术在条件生成式对抗网络训练时的稳定性不高和效率低的技术问题。
为实现上述目的,本发明提供了一种条件生成式对抗网络的训练方法,所述方法包括以下步骤:
获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;
对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;
获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;
基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;
基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。
优选地,所述基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:
基于瓦瑟斯坦生成式对抗网络Wasserstein GAN获取瓦瑟斯坦距离参数及梯度惩罚;
生成随机噪声信息;
根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器。
优选地,所述根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:
根据所述随机噪声信息和所述条件向量对所述生成器的输入层进行设置,以获得目标生成器;
根据所述瓦瑟斯坦距离参数及所述梯度惩罚对所述判别器进行设置,并根据所述条件向量在所述判别器的输出层中设置预设维度向量,以获得优化判别器;
去除所述优化判别器的输入层中的条件向量及激活层中的Sigmoid激活函数,以获得目标判别器。
优选地,所述基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练的步骤,具体包括:
保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器;
保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器;
对所述目标判别器和所述目标生成器执行迭代训练的次数进行设置,获得预设迭代次数;
根据所述预设迭代次数对所述目标判别器和所述目标生成器进行训练。
优选地,所述保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器的步骤,具体包括:
保持所述目标生成器各层的参数不变,基于所述条件向量生成第一条件向量;
将所述第一条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第一生成样本图片;
将所述真实样本图片与所述第一生成样本图片输入至所述目标判别器,以获得所述目标判别器输出的第一输出结果;
根据所述第一输出结果与第一目标输出计算所述目标判别器的判别损失;
根据所述判别损失更新所述目标判别器的参数,以实现对所述目标判别器的训练。
优选地,所述保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器的步骤,具体包括:
保持所述目标判别器各层的参数不变,基于所述条件向量生成第二条件向量;
将所述第二条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第二生成样本图片;
将所述第二生成样本图片输入所述目标判别器,以获得所述目标判别器输出的第二输出结果;
根据所述第二输出结果与第二目标输出计算所述目标生成器的生成损失;
根据所述生成损失更新所述目标生成器的参数,以实现对所述目标生成器的训练。
优选地,所述对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量的步骤,具体包括:
对所述目标样本图片进行分类以获得分类结果;
根据所述分类结果生成图片类别数,并将所述图片类别数作为预设维度;
根据所述预设维度设置条件向量,其中,所述条件向量采用One-Hot编码,同一类别的图片对应的条件向量相同。
此外,为实现上述目的,本发明还提出一种条件生成式对抗网络的训练装置,所述装置包括:
图片获取模块,用于获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;
条件设置模块,用于对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;
网络获取模块,用于获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;
网络改进模块,用于基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;
网络训练模块,用于基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。
此外,为实现上述目的,本发明还提出一种电子设备,所述设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的条件生成式对抗网络的训练程序,所述条件生成式对抗网络的训练程序配置为实现如上文所述的条件生成式对抗网络的训练方法的步骤。
此外,为实现上述目的,本发明还提出一种存储介质,所述存储介质上存储有条件生成式对抗网络的训练程序,所述条件生成式对抗网络的训练程序被处理器执行时实现如上文所述的条件生成式对抗网络的训练方法的步骤。
本发明通过获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。通过上述方式,基于Wasserstein GAN对条件生成式对抗网络进行改进,同时完成了稳定训练和进程指标的问题,解决了条件生成式对抗网络训练不稳定的问题并加速了条件生成式对抗网络的训练效率,从而解决了现有技术在条件生成式对抗网络训练时的稳定性不高和效率低的技术问题。
附图说明
图1为本发明实施例方案涉和的硬件运行环境的电子设备的结构示意图;
图2为本发明一种条件生成式对抗网络的训练方法第一实施例的流程示意图;
图3a为本发明实施例中条件生成式对抗网络训练20000次后生成的图像;
图3b为本发明实施例中目标判别器和目标生成器训练20000次后生成的图像;
图4a为本发明实施例中条件生成式对抗网络去掉Batch Normalization层后训练20000次后生成的图像;
图4b为本发明实施例中目标判别器和目标生成器去掉Batch Normalization层后训练20000次后生成的图像;
图5为本发明一种条件生成式对抗网络的训练方法第二实施例的流程示意图;
图6为本发明一种条件生成式对抗网络的训练装置第一实施例的结构框图。
本发明目的的实现、功能特点和优点将结合实施例,参照附图做进一步说明。
具体实施方式
应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
参照图1,图1为本发明实施例方案涉和的硬件运行环境的电子设备结构示意图。
如图1所示,该电子设备可以包括:处理器1001,例如中央处理器(CentralProcessing Unit,CPU),通信总线1002、用户接口1003,网络接口1004,存储器1005。其中,通信总线1002用于实现这些组件之间的连接通信。用户接口1003可以包括显示屏(Display)、输入单元比如键盘(Keyboard),可选用户接口1003还可以包括标准的有线接口、无线接口。网络接口1004可选的可以包括标准的有线接口、无线接口(如无线保真(WIreless-FIdelity,WI-FI)接口)。存储器1005可以是高速的随机存取存储器(RandomAccess Memory,RAM)存储器,也可以是稳定的非易失性存储器(Non-Volatile Memory,NVM),例如磁盘存储器。存储器1005可选的还可以是独立于前述处理器1001的存储装置。
本领域技术人员可以理解,图1中示出的结构并不构成对电子设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图1所示,作为一种存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块和条件生成式对抗网络的训练程序。
在图1所示的电子设备中,网络接口1004主要用于与网络服务器进行数据通信;用户接口1003主要用于与用户进行数据交互;本发明电子设备中的处理器1001、存储器1005可以设置在电子设备中,所述电子设备通过处理器1001调用存储器1005中存储的条件生成式对抗网络的训练程序,并执行本发明实施例提供的条件生成式对抗网络的训练方法。
本发明实施例提供了一种条件生成式对抗网络的训练方法,参照图2,图2为本发明一种条件生成式对抗网络的训练方法第一实施例的流程示意图。
本实施例中,所述条件生成式对抗网络的训练方法包括以下步骤:
步骤S10:获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片。
需要说明的是,根据生成目标获取真实样本图片,真实样本图片的获取环节负责采集足够多包含有丰富细节信息的、可供训练的真实样本图片。对所述真实样本图片进行图像预处理可以包括判断获取到的真实样本图片是否清晰、图片内容是否包括人像或风景、需要对真实样本图片加入描述条件以生成指定风格图片等,以获得目标样本图片。
具体地,根据生成目标如生成手写数字,可以使用MNIST手写数据集作为真实样本图片,MNIST手写数据集是灰度图片集,即二维数据;对所述真实样本图片进行图像预处理,MNIST手写数据集作为公开实验数据集,是已经经过预处理的数据集,可以不需要进行图像预处理。
步骤S20:对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量。
易于理解的是,所述对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量的步骤,具体包括:对所述目标样本图片进行分类以获得分类结果;根据所述分类结果生成图片类别数,并将所述图片类别数作为预设维度;根据所述预设维度设置条件向量,其中,所述条件向量采用One-Hot编码,同一类别的图片对应的条件向量相同。
具体地,对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量,例如采用MNIST手写数据集作为真实样本图片时可以不需要进行图像预处理,MNIST手写数据集是关于手写数字0-9的手写数字图像的集合,因此在设置条件向量的过程中:首先将数据集按照0-9的具体数字进行分类以获得分类结果为类别数10类,同时根据类别数将条件向量设置为10维,其中所述条件向量采用One-Hot编码,同一类别的图片对应的条件向量相同,条件向量可以是一个维数为10的向量。
步骤S30:获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器。
需要说明的是,需要说明的是,获取条件生成式对抗网络,所述条件生成式对抗网络包含了两个“对抗”的模型:生成器(G)用于捕捉数据分布,判别器(D)用于指导生成器生成不同条件的数据。条件生成式对抗网络是对生成式对抗网络的扩展,在生成器(D)的建模中引入条件向量,判别器为不同条件的输入分配不同的目标向量,可以指导数据生成过程。条件向量可以是任意信息,例如类别信息,或者其他模态的数据。通过将条件向量输送给生成器作为输入层的一部分,并作为判别器不同的目标向量,从而实现条件生成式对抗网络。
步骤S40:基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;
易于理解的是,所述基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:基于瓦瑟斯坦生成式对抗网络Wasserstein GAN获取瓦瑟斯坦距离参数及梯度惩罚;生成随机噪声信息;根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器。其中,所述根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:根据所述随机噪声信息和所述条件向量对所述生成器的输入层进行设置,以获得目标生成器;根据所述瓦瑟斯坦距离参数及所述梯度惩罚对所述判别器进行设置,并根据所述条件向量在所述判别器的输出层中设置预设维度向量,以获得优化判别器;去除所述优化判别器的输入层中的条件向量及激活层中的Sigmoid激活函数,以获得目标判别器。
具体地,通过瓦瑟斯坦生成式对抗网络Wasserstein GAN中的Wasserstein距离参数来衡量真实样本分布和生成样本分布之间的差距,由于Wasserstein距离参数满足两个分布之间没有交集,依然可以衡量样本分布之间的远近。将所述随机噪声信息和所述条件向量设置于生成器的输入层,将生成器的输出层设置为目标生成图片,去除判别器的输入层中的条件向量并将判别器的输出层由原始的一维标量改为n维条件向量,其中采用MNIST手写数据集作为真实样本图片时,将数据集按照0-9的具体数字进行分类以获得分类结果为类别数10类,同时根据类别数将条件向量设置为10维,条件向量可以是一个维数为10的向量,即n可以为10;去除所述判别器中激活层的Sigmoid激活函数,生成器与判别器的激活函数可以分别使用ReLU和Leaky ReLu作为激活函数,上述对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,最终获得目标判别器和目标生成器。
具体地,生成随机噪声信息用作生成器的输入,可以采用TensorFlow框架的内置函数np.random.uniform()生成随机噪声向量,生成方式为:在一个区间为-1到1之间的均匀分布中随机采样100次,随机噪声设置为一个100维数的向量,该向量可看成一个(1,1,100)的向量。
引入Wasserstein距离参数作为损失函数来衡量生成图片与目标图片的距离参数目的是同时完成稳定训练和进程指标的问题。用Wasserstein距离参数代替JS散度来衡量生成图片与目标图片的距离,解决了模式崩溃问题,并且持续地提供的梯度来指示训练的进程,去掉判别器的输出层的sigmoid激活函数,使判别器由解决一个二分类问题变为解决一个回归问题;生成器和判别器的损失函数(loss函数)都不带有对数计算(log计算)。
引入梯度惩罚(Gradient Penalty)以满足Wasserstein距离参数作为损失函数时对鉴别器的1-Lipschtiz限制,梯度惩罚是一种更加先进的Lipschitz限制手法,Lipschitz限制了判别器函数的梯度,使其不大于一个有限的常数K,这样就保证了输入经过微小变化后,输出不会发生剧烈的变化。梯度惩罚是对权重裁剪(Weight Clipping)的一种改进,它可以让梯度在反向传播过程中保持稳定,梯度惩罚的做法是对生成样本图片集中区域、真实样本图片集中区域以及夹在它们中间的区域加以限制,并且直接把目标判别器的梯度限制在1附近,避免在训练过程中可能产生的梯度消失或梯度爆炸现象。具体做法是:在计算目标判别器的判别损失时加入了一个额外项,对大于或小于1的目标判别器梯度施加梯度惩罚。
步骤S50:基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。
需要说明的是,所述基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练的步骤,具体包括:保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器;保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器;对所述目标判别器和所述目标生成器执行迭代训练的次数进行设置,获得预设迭代次数;根据所述预设迭代次数对所述目标判别器和所述目标生成器进行训练。参照图3a,图3a为本发明实施例中条件生成式对抗网络训练20000次后生成的图像,其中,digit表示数字。参照图3b,图3b为本发明实施例中目标判别器和目标生成器训练20000次后生成的图像,其中,digit表示数字。可见,本实施例扩大判别器输出层维度对条件生成式对抗网络进行改进后,目标判别器和目标生成器训练20000次后生成的图像比条件生成式对抗网络更清晰,本实施例同时基于Wasserstein距离参数作为条件生成式对抗网络的损失函数,完成了稳定训练和进程指标的问题,解决了条件生成式对抗网络训练不稳定的问题并加速了条件生成式对抗网络的训练效率。参照图4a,图4a为本发明实施例中条件生成式对抗网络去掉Batch Normalization层后训练20000次后生成的图像;参照图4b,图4b为本发明实施例中目标判别器和目标生成器去掉Batch Normalization层后训练20000次后生成的图像,将条件生成式对抗网络模型和本实施例中目标判别器和目标生成器构成的网络模型,均去掉Batch Normalization层(一种使网络模型更稳定的方式)进行训练得到对照组图4a以及图4b,参照图3a、图3b、图4a以及图4b说明本实施例目标判别器和目标生成器构成的网络模型结构更加稳定,观察对照组图4a以及图4b发现本实施例目标判别器和目标生成器几乎不受Batch Normalization层影响。
具体地,所述保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器的步骤,具体包括:保持所述目标生成器各层的参数不变,基于所述条件向量生成第一条件向量;将所述第一条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第一生成样本图片;将所述真实样本图片与所述第一生成样本图片输入至所述目标判别器,以获得所述目标判别器输出的第一输出结果;根据所述第一输出结果与第一目标输出计算所述目标判别器的判别损失;根据所述判别损失更新所述目标判别器的参数,以实现对所述目标判别器的训练。
具体地,所述保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器的步骤,具体包括:保持所述目标判别器各层的参数不变,基于所述条件向量生成第二条件向量;将所述第二条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第二生成样本图片;将所述第二生成样本图片输入所述目标判别器,以获得所述目标判别器输出的第二输出结果;根据所述第二输出结果与第二目标输出计算所述目标生成器的生成损失;根据所述生成损失更新所述目标生成器的参数,以实现对所述目标生成器的训练。
易于理解的是,基于Wasserstein距离参数的损失函数允许改进的条件生成式对抗网络模型训练最优判别器,可以设置所述目标判别器与所述目标生成器的预设迭代次数为5:1,对所述目标判别器与所述目标生成器进行训练所用的损失函数可以基于Wasserstein距离参数,并通过梯度惩罚作为限制,对所述目标判别器与所述目标生成器进行训练所用的优化器可以为RMS Prop优化算法,代替常用的Adam优化算法。
本实施例通过获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。通过上述方式,损失函数基于Wasserstein距离来衡量真实数据与生成数据的距离;生成器输入为噪音和条件,输出为生成的图片;判别器输入为生成器生成的图片和真实图片,输出为与条件类别个数一致的n维向量;扩大判别器输出层维度,可以指导不同条件图片的生成且提高生成图片的质量;通过Wasserstein距离参数作为损失函数完成了稳定训练和进程指标的问题,解决了现有技术在条件生成式对抗网络训练时稳定性不高的同时提高了生成图片的质量。基于Wasserstein GAN对条件生成式对抗网络进行改进,同时完成了稳定训练和进程指标的问题,解决了条件生成式对抗网络训练不稳定的问题并加速了条件生成式对抗网络的训练效率,从而解决了现有技术在条件生成式对抗网络训练时的稳定性不高和效率低的技术问题。
参考图5,图5为本发明一种条件生成式对抗网络的训练方法第二实施例的流程示意图。基于上述第一实施例,本实施例条件生成式对抗网络的训练方法在所述步骤S50,具体包括:
S501:保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器。
需要说明的是,所述保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器的步骤,具体包括:保持所述目标生成器各层的参数不变,基于所述条件向量生成第一条件向量;将所述第一条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第一生成样本图片;将所述真实样本图片与所述第一生成样本图片输入至所述目标判别器,以获得所述目标判别器输出的第一输出结果;根据所述第一输出结果与第一目标输出计算所述目标判别器的判别损失;根据所述判别损失更新所述目标判别器的参数,以实现对所述目标判别器的训练。
具体地,保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器,训练所述目标判别器判断一个样本是真实样本图片还是生成器输出的生成样本图片的能力;使真实样本图片拟合相对应的真实标签,所述目标生成器输出的第一生成样本图片拟合相对应的错误标签。
易于理解的是,将所述真实样本图片与所述第一生成样本图片输入至所述目标判别器,衡量所述目标判别器输出的第一输出结果和所述真实样本图片之间的差异,计算判别损失,所述判别损失根据Wasserstein距离参数来计算。将所述判别损失从所述目标判别器的输出层向隐藏层反向传播,直至传播到输入层,在这个过程中使用RMS Prop优化算法对所述目标判别器的参数进行更新。更新完成后,再次使用所述目标判别器对生成样本和真实样本进行鉴别,直到所述目标判别器可以正确区分真实样本图片与第一生成样本图片,此时所述目标判别器训练暂时完成。所述目标判别器中还可以设置一个判别监控器,该判别监控器可以通过所述目标判别器的判别损失来检查所述目标判别器是否具有分辨真实样本图片与第一生成样本图片的能力。
S502:保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器。
易于理解的是,所述保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器的步骤,具体包括:保持所述目标判别器各层的参数不变,基于所述条件向量生成第二条件向量;将所述第二条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第二生成样本图片;将所述第二生成样本图片输入所述目标判别器,以获得所述目标判别器输出的第二输出结果;根据所述第二输出结果与第二目标输出计算所述目标生成器的生成损失;根据所述生成损失更新所述目标生成器的参数,以实现对所述目标生成器的训练。
具体地,保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器,训练所述目标生成器生成第二生成样本图片,并让所述目标判别器无法判断第二生成样本图片是否为所述目标生成器生成的。将所述目标生成器输出的第二生成样本图片拟合对应的真实标签。
需要说明的是,将所述第二生成样本图片输入所述目标判别器,以获得所述目标判别器输出的第二输出结果;根据所述第二输出结果与第二目标输出计算所述目标生成器的生成损失,所述生成损失根据Wasserstein距离参数来计算。
S503:对所述目标判别器和所述目标生成器执行迭代训练的次数进行设置,获得预设迭代次数。
易于理解的是,基于Wasserstein距离参数的损失函数允许改进的条件生成式对抗网络模型训练最优判别器,可以设置所述目标判别器与所述目标生成器的预设迭代次数为5:1,对所述目标生成器和所述目标鉴别器进行训练所用的损失函数可以基于Wasserstein距离参数,并通过梯度惩罚作为限制,对所述目标生成器和所述目标鉴别器进行训练所用的优化器可以为Adam。迭代训练所述目标生成器和所述目标鉴别器,最终所述目标生成器可以根据条件向量生成高质量的图片。
S504:根据所述预设迭代次数对所述目标判别器和所述目标生成器进行训练。
本实施例通过保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器;保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器;对所述目标判别器和所述目标生成器执行迭代训练的次数进行设置,获得预设迭代次数;根据所述预设迭代次数对所述目标判别器和所述目标生成器进行训练。通过上述方式,目标判别器和目标生成器的训练速度会更快,效率会提升,训练效果也会有一定提升,解决了条件生成式对抗网络训练不稳定的问题并加速了训练效率,从而解决了现有技术在条件生成式对抗网络训练时的稳定性不高和效率低的技术问题。
此外,本发明实施例还提出一种存储介质,所述存储介质上存储有条件生成式对抗网络的训练程序,所述条件生成式对抗网络的训练程序被处理器执行时实现如上文所述的条件生成式对抗网络的训练方法的步骤。
参照图6,图6为本发明条件生成式对抗网络的训练装置第一实施例的结构框图。
如图6所示,本发明实施例条件生成式对抗网络的训练装置包括:
图片获取模块10,用于获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片。
需要说明的是,根据生成目标获取真实样本图片,真实样本图片的获取环节负责采集足够多包含有丰富细节信息的、可供训练的真实样本图片。对所述真实样本图片进行图像预处理可以包括判断获取到的真实样本图片是否清晰、图片内容是否包括人像或风景、需要对真实样本图片加入描述条件以生成指定风格图片等,以获得目标样本图片。
具体地,根据生成目标如生成手写数字,可以使用MNIST手写数据集作为真实样本图片,MNIST手写数据集是灰度图片集,即二维数据;对所述真实样本图片进行图像预处理,MNIST手写数据集作为公开实验数据集,是已经经过预处理的数据集,可以不需要进行图像预处理。
条件设置模块20,用于对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量。
易于理解的是,所述对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量的步骤,具体包括:对所述目标样本图片进行分类以获得分类结果;根据所述分类结果生成图片类别数,并将所述图片类别数作为预设维度;根据所述预设维度设置条件向量,其中,所述条件向量采用One-Hot编码,同一类别的图片对应的条件向量相同。
具体地,对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量,例如采用MNIST手写数据集作为真实样本图片时可以不需要进行图像预处理,MNIST手写数据集是关于手写数字0-9的手写数字图像的集合,因此在设置条件向量的过程中:首先将数据集按照0-9的具体数字进行分类以获得分类结果为类别数10类,同时根据类别数将条件向量设置为10维,其中所述条件向量采用One-Hot编码,同一类别的图片对应的条件向量相同,条件向量可以是一个维数为10的向量。
网络获取模块30,用于获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器。
需要说明的是,需要说明的是,获取条件生成式对抗网络,所述条件生成式对抗网络包含了两个“对抗”的模型:生成器(G)用于捕捉数据分布,判别器(D)用于指导生成器生成不同条件的数据。条件生成式对抗网络是对生成式对抗网络的扩展,在生成器(D)的建模中引入条件向量,判别器为不同条件的输入分配不同的目标向量,可以指导数据生成过程。条件向量可以是任意信息,例如类别信息,或者其他模态的数据。通过将条件向量输送给生成器作为输入层的一部分,并作为判别器不同的目标向量,从而实现条件生成式对抗网络。
网络改进模块40,用于基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器。
易于理解的是,所述基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:基于瓦瑟斯坦生成式对抗网络Wasserstein GAN获取瓦瑟斯坦距离参数及梯度惩罚;生成随机噪声信息;根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器。其中,所述根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:根据所述随机噪声信息和所述条件向量对所述生成器的输入层进行设置,以获得目标生成器;根据所述瓦瑟斯坦距离参数及所述梯度惩罚对所述判别器进行设置,并根据所述条件向量在所述判别器的输出层中设置预设维度向量,以获得优化判别器;去除所述优化判别器的输入层中的条件向量及激活层中的Sigmoid激活函数,以获得目标判别器。
具体地,通过瓦瑟斯坦生成式对抗网络Wasserstein GAN中的Wasserstein距离参数来衡量真实样本分布和生成样本分布之间的差距,由于Wasserstein距离参数满足两个分布之间没有交集,依然可以衡量样本分布之间的远近。将所述随机噪声信息和所述条件向量设置于生成器的输入层,将生成器的输出层设置为目标生成图片,去除判别器的输入层中的条件向量并将判别器的输出层由原始的一维标量改为n维条件向量,其中采用MNIST手写数据集作为真实样本图片时,将数据集按照0-9的具体数字进行分类以获得分类结果为类别数10类,同时根据类别数将条件向量设置为10维,条件向量可以是一个维数为10的向量,即n可以为10;去除所述判别器中激活层的Sigmoid激活函数,生成器与判别器的激活函数可以分别使用ReLU和Leaky ReLu作为激活函数,上述对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,最终获得目标判别器和目标生成器。
具体地,生成随机噪声信息用作生成器的输入,可以采用TensorFlow框架的内置函数np.random.uniform()生成随机噪声向量,生成方式为:在一个区间为-1到1之间的均匀分布中随机采样100次,随机噪声设置为一个100维数的向量,该向量可看成一个(1,1,100)的向量。
引入Wasserstein距离参数作为损失函数来衡量生成图片与目标图片的距离参数目的是同时完成稳定训练和进程指标的问题。用Wasserstein距离参数代替JS散度来衡量生成图片与目标图片的距离,解决了模式崩溃问题,并且持续地提供的梯度来指示训练的进程,去掉判别器的输出层的sigmoid激活函数,使判别器由解决一个二分类问题变为解决一个回归问题;生成器和判别器的损失函数(loss函数)都不带有对数计算(log计算)。
引入梯度惩罚(Gradient Penalty)以满足Wasserstein距离参数作为损失函数时对鉴别器的1-Lipschtiz限制,梯度惩罚是一种更加先进的Lipschitz限制手法,Lipschitz限制了判别器函数的梯度,使其不大于一个有限的常数K,这样就保证了输入经过微小变化后,输出不会发生剧烈的变化。梯度惩罚是对权重裁剪(Weight Clipping)的一种改进,它可以让梯度在反向传播过程中保持稳定,梯度惩罚的做法是对生成样本图片集中区域、真实样本图片集中区域以及夹在它们中间的区域加以限制,并且直接把目标判别器的梯度限制在1附近,避免在训练过程中可能产生的梯度消失或梯度爆炸现象。具体做法是:在计算目标判别器的判别损失时加入了一个额外项,对大于或小于1的目标判别器梯度施加梯度惩罚。
网络训练模块50,用于基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。
需要说明的是,所述基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练的步骤,具体包括:保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器;保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器;对所述目标判别器和所述目标生成器执行迭代训练的次数进行设置,获得预设迭代次数;根据所述预设迭代次数对所述目标判别器和所述目标生成器进行训练。参照图3a,图3a为本发明实施例中条件生成式对抗网络训练20000次后生成的图像。参照图3b,图3b为本发明实施例中目标判别器和目标生成器训练20000次后生成的图像。可见,本实施例扩大判别器输出层维度对条件生成式对抗网络进行改进后,目标判别器和目标生成器训练20000次后生成的图像比条件生成式对抗网络更清晰,本实施例同时基于Wasserstein距离参数作为条件生成式对抗网络的损失函数,完成了稳定训练和进程指标的问题,解决了条件生成式对抗网络训练不稳定的问题并加速了条件生成式对抗网络的训练效率。参照图4a,图4a为本发明实施例中条件生成式对抗网络去掉Batch Normalization层后训练20000次后生成的图像;参照图4b,图4b为本发明实施例中目标判别器和目标生成器去掉Batch Normalization层后训练20000次后生成的图像,将条件生成式对抗网络模型和本实施例中目标判别器和目标生成器构成的网络模型,均去掉Batch Normalization层(一种使网络模型更稳定的方式)进行训练得到对照组图4a以及图4b,参照图3a、图3b、图4a以及图4b说明本实施例目标判别器和目标生成器构成的网络模型结构更加稳定,观察对照组图4a以及图4b发现本实施例目标判别器和目标生成器几乎不受Batch Normalization层影响。
具体地,所述保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器的步骤,具体包括:保持所述目标生成器各层的参数不变,基于所述条件向量生成第一条件向量;将所述第一条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第一生成样本图片;将所述真实样本图片与所述第一生成样本图片输入至所述目标判别器,以获得所述目标判别器输出的第一输出结果;根据所述第一输出结果与第一目标输出计算所述目标判别器的判别损失;根据所述判别损失更新所述目标判别器的参数,以实现对所述目标判别器的训练。
具体地,所述保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器的步骤,具体包括:保持所述目标判别器各层的参数不变,基于所述条件向量生成第二条件向量;将所述第二条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第二生成样本图片;将所述第二生成样本图片输入所述目标判别器,以获得所述目标判别器输出的第二输出结果;根据所述第二输出结果与第二目标输出计算所述目标生成器的生成损失;根据所述生成损失更新所述目标生成器的参数,以实现对所述目标生成器的训练。
易于理解的是,基于Wasserstein距离参数的损失函数允许改进的条件生成式对抗网络模型训练最优判别器,可以设置所述目标判别器与所述目标生成器的预设迭代次数为5:1,对所述目标判别器与所述目标生成器进行训练所用的损失函数可以为基于Wasserstein距离参数,并通过梯度惩罚作为限制,对所述目标判别器与所述目标生成器进行训练所用的优化器可以为RMS Prop优化算法,代替常用的Adam优化算法。
本实施例通过图片获取模块10,用于获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;条件设置模块20,用于对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;网络获取模块30,用于获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;网络改进模块40,用于基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;网络训练模块50,用于基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。通过上述方式,损失函数基于Wasserstein距离来衡量真实数据与生成数据的距离;生成器输入为噪音和条件,输出为生成的图片;判别器输入为生成器生成的图片和真实图片,输出为与条件类别个数一致的n维向量;扩大判别器输出层维度,可以指导不同条件图片的生成且提高生成图片的质量;通过Wasserstein距离参数作为损失函数完成了稳定训练和进程指标的问题,解决了现有技术在条件生成式对抗网络训练时稳定性不高的同时提高了生成图片的质量。基于Wasserstein GAN对条件生成式对抗网络进行改进,同时完成了稳定训练和进程指标的问题,解决了条件生成式对抗网络训练不稳定的问题并加速了训练效率,从而解决了现有技术在条件生成式对抗网络训练时的稳定性不高和效率低的技术问题。
应当理解的是,以上仅为举例说明,对本发明的技术方案并不构成任何限定,在具体应用中,本领域的技术人员可以根据需要进行设置,本发明对此不做限制。
需要说明的是,以上所描述的工作流程仅仅是示意性的,并不对本发明的保护范围构成限定,在实际应用中,本领域的技术人员可以根据实际的需要选择其中的部分或者全部来实现本实施例方案的目的,此处不做限制。
另外,未在本实施例中详尽描述的技术细节,可参见本发明任意实施例所提供的条件生成式对抗网络的训练方法,此处不再赘述。
此外,需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者系统不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者系统所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者系统中还存在另外的相同要素。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如只读存储器(Read Only Memory,ROM)/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书和附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (10)
1.一种条件生成式对抗网络的训练方法,其特征在于,所述方法包括:
获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;
对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;
获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;
基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;
基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。
2.如权利要求1所述的条件生成式对抗网络的训练方法,其特征在于,所述基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:
基于瓦瑟斯坦生成式对抗网络Wasserstein GAN获取瓦瑟斯坦距离参数及梯度惩罚;
生成随机噪声信息;
根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器。
3.如权利要求2所述的条件生成式对抗网络的训练方法,其特征在于,所述根据所述随机噪声信息、所述瓦瑟斯坦距离参数、所述梯度惩罚及所述条件向量对所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器的步骤,具体包括:
根据所述随机噪声信息和所述条件向量对所述生成器的输入层进行设置,以获得目标生成器;
根据所述瓦瑟斯坦距离参数及所述梯度惩罚对所述判别器进行设置,并根据所述条件向量在所述判别器的输出层中设置预设维度向量,以获得优化判别器;
去除所述优化判别器的输入层中的条件向量及激活层中的Sigmoid激活函数,以获得目标判别器。
4.如权利要求3所述的条件生成式对抗网络的训练方法,其特征在于,所述基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练的步骤,具体包括:
保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器;
保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器;
对所述目标判别器和所述目标生成器执行迭代训练的次数进行设置,获得预设迭代次数;
根据所述预设迭代次数对所述目标判别器和所述目标生成器进行训练。
5.如权利要求4所述的条件生成式对抗网络的训练方法,其特征在于,所述保持所述目标生成器各层的参数不变,基于所述真实样本图片及所述条件向量训练所述目标判别器的步骤,具体包括:
保持所述目标生成器各层的参数不变,基于所述条件向量生成第一条件向量;
将所述第一条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第一生成样本图片;
将所述真实样本图片与所述第一生成样本图片输入至所述目标判别器,以获得所述目标判别器输出的第一输出结果;
根据所述第一输出结果与第一目标输出计算所述目标判别器的判别损失;
根据所述判别损失更新所述目标判别器的参数,以实现对所述目标判别器的训练。
6.如权利要求4所述的条件生成式对抗网络的训练方法,其特征在于,所述保持所述目标判别器各层的参数不变,基于所述条件向量和所述随机噪声信息训练所述目标生成器的步骤,具体包括:
保持所述目标判别器各层的参数不变,基于所述条件向量生成第二条件向量;
将所述第二条件向量与所述随机噪声信息输入至所述目标生成器,以获得所述目标生成器输出的第二生成样本图片;
将所述第二生成样本图片输入所述目标判别器,以获得所述目标判别器输出的第二输出结果;
根据所述第二输出结果与第二目标输出计算所述目标生成器的生成损失;
根据所述生成损失更新所述目标生成器的参数,以实现对所述目标生成器的训练。
7.如权利要求1所述的条件生成式对抗网络的训练方法,其特征在于,所述对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量的步骤,具体包括:
对所述目标样本图片进行分类以获得分类结果;
根据所述分类结果生成图片类别数,并将所述图片类别数作为预设维度;
根据所述预设维度设置条件向量,其中,所述条件向量采用One-Hot编码,同一类别的图片对应的条件向量相同。
8.一种条件生成式对抗网络的训练装置,其特征在于,所述装置包括:
图片获取模块,用于获取真实样本图片,对所述真实样本图片进行图像预处理,以获得目标样本图片;
条件设置模块,用于对所述目标样本图片进行分类以获得分类结果,并根据所述分类结果设置条件向量;
网络获取模块,用于获取条件生成式对抗网络,所述条件生成式对抗网络包括生成器和判别器;
网络改进模块,用于基于瓦瑟斯坦生成式对抗网络Wasserstein GAN及所述条件向量对所述条件生成式对抗网络中的所述生成器和所述判别器进行设置,以获得目标判别器和目标生成器;
网络训练模块,用于基于所述真实样本图片及所述条件向量对所述目标生成器和所述目标判别器进行训练。
9.一种电子设备,其特征在于,所述设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的条件生成式对抗网络的训练程序,所述条件生成式对抗网络的训练程序配置为实现如权利要求1至7中任一项所述的条件生成式对抗网络的训练方法的步骤。
10.一种存储介质,其特征在于,所述存储介质上存储有条件生成式对抗网络的训练程序,所述条件生成式对抗网络的训练程序被处理器执行时实现如权利要求1至7任一项所述的条件生成式对抗网络的训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010359482.5A CN111582348B (zh) | 2020-04-29 | 2020-04-29 | 条件生成式对抗网络的训练方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010359482.5A CN111582348B (zh) | 2020-04-29 | 2020-04-29 | 条件生成式对抗网络的训练方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111582348A true CN111582348A (zh) | 2020-08-25 |
CN111582348B CN111582348B (zh) | 2024-02-27 |
Family
ID=72125007
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010359482.5A Active CN111582348B (zh) | 2020-04-29 | 2020-04-29 | 条件生成式对抗网络的训练方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111582348B (zh) |
Cited By (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112329568A (zh) * | 2020-10-27 | 2021-02-05 | 西安晟昕科技发展有限公司 | 一种辐射源信号生成方法、装置及存储介质 |
CN112365557A (zh) * | 2020-11-13 | 2021-02-12 | 北京京东尚科信息技术有限公司 | 图片生成方法、模型训练方法、装置和存储介质 |
CN112541557A (zh) * | 2020-12-25 | 2021-03-23 | 北京百度网讯科技有限公司 | 生成式对抗网络的训练方法、装置及电子设备 |
CN112598034A (zh) * | 2020-12-09 | 2021-04-02 | 华东交通大学 | 基于生成式对抗网络的矿石图像生成方法和计算机可读存储介质 |
CN112613494A (zh) * | 2020-11-19 | 2021-04-06 | 北京国网富达科技发展有限责任公司 | 基于深度对抗网络的电力线路监控异常识别方法及系统 |
CN112766348A (zh) * | 2021-01-12 | 2021-05-07 | 云南电网有限责任公司电力科学研究院 | 一种基于对抗神经网络生成样本数据的方法以及装置 |
CN113505876A (zh) * | 2021-06-11 | 2021-10-15 | 国网浙江省电力有限公司嘉兴供电公司 | 一种基于生成式对抗网络的高压断路器故障诊断方法 |
WO2022126480A1 (zh) * | 2020-12-17 | 2022-06-23 | 深圳先进技术研究院 | 基于Wasserstein生成对抗网络模型的高能图像合成方法、装置 |
CN114863225A (zh) * | 2022-07-06 | 2022-08-05 | 腾讯科技(深圳)有限公司 | 图像处理模型训练方法、生成方法、装置、设备及介质 |
CN115357941A (zh) * | 2022-10-20 | 2022-11-18 | 北京宽客进化科技有限公司 | 一种基于生成式人工智能的去隐私方法和系统 |
CN116010609A (zh) * | 2023-03-23 | 2023-04-25 | 山东中翰软件有限公司 | 一种物料数据归类方法、装置、电子设备及存储介质 |
CN117195743A (zh) * | 2023-10-16 | 2023-12-08 | 西安交通大学 | 一种热障涂层裂纹结构的喷涂参数优化方法 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107563510A (zh) * | 2017-08-14 | 2018-01-09 | 华南理工大学 | 一种基于深度卷积神经网络的wgan模型方法 |
CN109389080A (zh) * | 2018-09-30 | 2019-02-26 | 西安电子科技大学 | 基于半监督wgan-gp的高光谱图像分类方法 |
CN109584337A (zh) * | 2018-11-09 | 2019-04-05 | 暨南大学 | 一种基于条件胶囊生成对抗网络的图像生成方法 |
US20190130903A1 (en) * | 2017-10-27 | 2019-05-02 | Baidu Usa Llc | Systems and methods for robust speech recognition using generative adversarial networks |
CN110070124A (zh) * | 2019-04-15 | 2019-07-30 | 广州小鹏汽车科技有限公司 | 一种基于生成式对抗网络的图像扩增方法及系统 |
WO2019210303A1 (en) * | 2018-04-27 | 2019-10-31 | Carnegie Mellon University | Improved generative adversarial networks having ranking loss |
CN110598806A (zh) * | 2019-07-29 | 2019-12-20 | 合肥工业大学 | 一种基于参数优化生成对抗网络的手写数字生成方法 |
-
2020
- 2020-04-29 CN CN202010359482.5A patent/CN111582348B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107563510A (zh) * | 2017-08-14 | 2018-01-09 | 华南理工大学 | 一种基于深度卷积神经网络的wgan模型方法 |
US20190130903A1 (en) * | 2017-10-27 | 2019-05-02 | Baidu Usa Llc | Systems and methods for robust speech recognition using generative adversarial networks |
WO2019210303A1 (en) * | 2018-04-27 | 2019-10-31 | Carnegie Mellon University | Improved generative adversarial networks having ranking loss |
CN109389080A (zh) * | 2018-09-30 | 2019-02-26 | 西安电子科技大学 | 基于半监督wgan-gp的高光谱图像分类方法 |
CN109584337A (zh) * | 2018-11-09 | 2019-04-05 | 暨南大学 | 一种基于条件胶囊生成对抗网络的图像生成方法 |
CN110070124A (zh) * | 2019-04-15 | 2019-07-30 | 广州小鹏汽车科技有限公司 | 一种基于生成式对抗网络的图像扩增方法及系统 |
CN110598806A (zh) * | 2019-07-29 | 2019-12-20 | 合肥工业大学 | 一种基于参数优化生成对抗网络的手写数字生成方法 |
Non-Patent Citations (2)
Title |
---|
XIANGRUI XU.ET.: "A novel method for identifying the deep neural network model with Serial Number", vol. 4, pages 1 - 9 * |
冯永等: "GP-WIRGAN:梯度惩罚优化的Wasserstein图像循环生成对抗网络模型", vol. 43, no. 2, pages 190 - 205 * |
Cited By (16)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112329568A (zh) * | 2020-10-27 | 2021-02-05 | 西安晟昕科技发展有限公司 | 一种辐射源信号生成方法、装置及存储介质 |
CN112365557A (zh) * | 2020-11-13 | 2021-02-12 | 北京京东尚科信息技术有限公司 | 图片生成方法、模型训练方法、装置和存储介质 |
CN112365557B (zh) * | 2020-11-13 | 2024-04-09 | 北京京东尚科信息技术有限公司 | 图片生成方法、模型训练方法、装置和存储介质 |
CN112613494A (zh) * | 2020-11-19 | 2021-04-06 | 北京国网富达科技发展有限责任公司 | 基于深度对抗网络的电力线路监控异常识别方法及系统 |
CN112598034A (zh) * | 2020-12-09 | 2021-04-02 | 华东交通大学 | 基于生成式对抗网络的矿石图像生成方法和计算机可读存储介质 |
WO2022126480A1 (zh) * | 2020-12-17 | 2022-06-23 | 深圳先进技术研究院 | 基于Wasserstein生成对抗网络模型的高能图像合成方法、装置 |
CN112541557B (zh) * | 2020-12-25 | 2024-04-05 | 北京百度网讯科技有限公司 | 生成式对抗网络的训练方法、装置及电子设备 |
CN112541557A (zh) * | 2020-12-25 | 2021-03-23 | 北京百度网讯科技有限公司 | 生成式对抗网络的训练方法、装置及电子设备 |
CN112766348A (zh) * | 2021-01-12 | 2021-05-07 | 云南电网有限责任公司电力科学研究院 | 一种基于对抗神经网络生成样本数据的方法以及装置 |
CN113505876A (zh) * | 2021-06-11 | 2021-10-15 | 国网浙江省电力有限公司嘉兴供电公司 | 一种基于生成式对抗网络的高压断路器故障诊断方法 |
CN114863225A (zh) * | 2022-07-06 | 2022-08-05 | 腾讯科技(深圳)有限公司 | 图像处理模型训练方法、生成方法、装置、设备及介质 |
CN114863225B (zh) * | 2022-07-06 | 2022-10-04 | 腾讯科技(深圳)有限公司 | 图像处理模型训练方法、生成方法、装置、设备及介质 |
CN115357941A (zh) * | 2022-10-20 | 2022-11-18 | 北京宽客进化科技有限公司 | 一种基于生成式人工智能的去隐私方法和系统 |
CN116010609A (zh) * | 2023-03-23 | 2023-04-25 | 山东中翰软件有限公司 | 一种物料数据归类方法、装置、电子设备及存储介质 |
CN116010609B (zh) * | 2023-03-23 | 2023-06-09 | 山东中翰软件有限公司 | 一种物料数据归类方法、装置、电子设备及存储介质 |
CN117195743A (zh) * | 2023-10-16 | 2023-12-08 | 西安交通大学 | 一种热障涂层裂纹结构的喷涂参数优化方法 |
Also Published As
Publication number | Publication date |
---|---|
CN111582348B (zh) | 2024-02-27 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111582348B (zh) | 条件生成式对抗网络的训练方法、装置、设备及存储介质 | |
Hussain et al. | A real time face emotion classification and recognition using deep learning model | |
Merrick et al. | The explanation game: Explaining machine learning models using shapley values | |
Otten et al. | Event generation and statistical sampling for physics with deep generative models and a density information buffer | |
Lu et al. | Image generation from sketch constraint using contextual gan | |
Fleuret et al. | Comparing machines and humans on a visual categorization test | |
Ozdemir et al. | Feature Engineering Made Easy: Identify unique features from your dataset in order to build powerful machine learning systems | |
Domeniconi et al. | Composite kernels for semi-supervised clustering | |
CN112418320B (zh) | 一种企业关联关系识别方法、装置及存储介质 | |
Walsh et al. | Automated human cell classification in sparse datasets using few-shot learning | |
CN111582136A (zh) | 表情识别方法及装置、电子设备、存储介质 | |
EP3916597A1 (en) | Detecting malware with deep generative models | |
CN111598153B (zh) | 数据聚类的处理方法、装置、计算机设备和存储介质 | |
CN111383217B (zh) | 大脑成瘾性状评估的可视化方法、装置及介质 | |
EP3971773A1 (en) | Visualization method and device for evaluating brain addiction traits, and medium | |
Mejia-Escobar et al. | Towards a better performance in facial expression recognition: a data-centric approach | |
CN112348808A (zh) | 屏幕透图检测方法及装置 | |
CN112749737A (zh) | 图像分类方法及装置、电子设备、存储介质 | |
CN116363732A (zh) | 人脸情绪识别方法、装置、设备及存储介质 | |
Rajeev et al. | Data Augmentation in Classifying Chest Radiograph Images (CXR) Using DCGAN-CNN | |
Vu et al. | c-Eval: A unified metric to evaluate feature-based explanations via perturbation | |
CN115033700A (zh) | 基于相互学习网络的跨领域情感分析方法、装置以及设备 | |
CN110458058B (zh) | 表情的识别方法和装置 | |
Stippinger et al. | BiometricBlender: Ultra-high dimensional, multi-class synthetic data generator to imitate biometric feature space | |
Akça et al. | A Deep Transfer Learning Based Visual Complexity Evaluation Approach to Mobile User Interfaces |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |