CN115329938A - 一种基于鉴别器森林提高生成对抗网络泛化能力的方法 - Google Patents
一种基于鉴别器森林提高生成对抗网络泛化能力的方法 Download PDFInfo
- Publication number
- CN115329938A CN115329938A CN202210994734.0A CN202210994734A CN115329938A CN 115329938 A CN115329938 A CN 115329938A CN 202210994734 A CN202210994734 A CN 202210994734A CN 115329938 A CN115329938 A CN 115329938A
- Authority
- CN
- China
- Prior art keywords
- discriminator
- training
- forest
- generator
- sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Data Exchanges In Wide-Area Networks (AREA)
Abstract
本发明公开了一种基于鉴别器森林提高生成对抗网络泛化能力的方法,包括:构建由一个生成器和鉴别器森林组成的生成对抗网络模型;生成训练样本,基于训练样本对鉴别器森林进行训练;通过生成器得到第一生成样本,将第一生成样本输入到训练过的鉴别器森林中的每个鉴别器中,得到鉴别器的损失值,以得到训练梯度,并更新生成器的训练参数,再以最小化的方式对生成器进行训练,得到更新的生成器;基于更新的生成器,重新执行生成训练样本的步骤,以对鉴别器进行下一次迭代训练。本发明提出由鉴别器森林和一个生成器组成的生成对抗网络模型,能够通过鉴别器数量的增加,降低泛化误差上界,提高生成样本的质量和多样性,增强模型的隐私保护能力。
Description
技术领域
本发明涉及机器学习领域,具体涉及一种基于鉴别器森林提高生成对抗网络泛化能力的方法。
背景技术
近年来,随着复杂分布上无监督学习技术的发展,生成对抗网络(GenerativeAdversarialNetworks,GAN)模型,通过生成对抗网络,根据已知的某种类型的样本数据集生成新的样本,在机器学习领域得到了广泛的应用。GAN由一个生成器和一个判别器构成,使用对抗学习的方式迭代训练生成器与判别器,最终估测出训练数据的分布,并利用训练好的生成器模型生成新样本。
但是,在当前基于GAN模型的研究中,很多工作都是针对解决训练不稳定的问题,极少研究者会关心GAN模型的泛化能力,但是缺少泛化能力的GAN模型很容易出现模型坍塌和隐私泄露等问题;因为性能强的鉴别器会引导生成器走向坍塌均衡,容量小的鉴别器会使生成器记忆真实数据,导致GAN模型容易受到潜在攻击和隐私泄露等问题。现有技术中对于带有多鉴别器的GAN模型的泛化能力依然有待提高。
因此,现有技术还有待于改进和发展。
发明内容
本发明要解决的技术问题在于,针对现有技术的上述缺陷,提供一种基于鉴别器森林提高生成对抗网络泛化能力的方法,提出带有多鉴别器的GAN模型,并证明其泛化能力。
为了解决上述技术问题,本发明解决技术问题所采用的技术方案如下:
第一方面,本发明提供一种基于鉴别器森林提高生成对抗网络泛化能力的方法,其中,所述方法包括:
构建由一个生成器和鉴别器森林组成的生成对抗网络模型;
生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林;
通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到所述鉴别器的损失值;
根据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器;
基于所述更新的生成器,重新执行所述生成训练样本的步骤,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。
在一种实现方式中,所述构建由一个生成器和鉴别器森林组成的生成对抗网络模型之前,包括:
构建所述生成器的神经网络结构,在卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活;
为所有鉴别器构建相同的神经网络结构,在每个鉴别器的卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活。
在一种实现方式中,所述生成训练样本,包括:
获取真实样本,并根据所述真实样本构建原始样本集,通过Bootstrap采样方法进行随机有放回地采样,得到与若干鉴别器一一对应的若干训练数据集;
将随机采样的128维的高斯噪声数据输入所述生成器,得到第二生成样本;
根据所述训练数据集中的真实样本和所述第二生成样本为每个鉴别器分别生成独立的训练样本;
基于所述训练样本,针对每一个鉴别器分别利用Adam优化器,采用交叉熵损失函数以最大化的方式进行训练,得到所述训练过的鉴别器森林。
在一种实现方式中,所述根据所述鉴别器的损失值得到训练梯度,包括:
将全部鉴别器的损失值进行均值汇总,得到均值汇总损失值;
根据所述均值汇总损失值计算所述训练梯度。
在一种实现方式中,所述根据所述鉴别器的损失值得到训练梯度,包括:
将全部鉴别器的损失值进行加权汇总,得到加权汇总损失值;
根据所述加权汇总损失值计算所述训练梯度。
在一种实现方式中,所述方法还包括:
设置所述生成器和所述鉴别器森林在一次迭代训练中的训练次数比例为1:1。
第二方面,本发明实施例还提供一种基于鉴别器森林提高生成对抗网络泛化能力的装置,其中,所述装置包括:
生成对抗网络模型构建模块,用于构建由一个生成器和鉴别器森林组成的生成对抗网络模型;
鉴别器训练模块,用于生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林;
损失值获取模块,用于通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到所述鉴别器的损失值;
生成器训练模块,用于根据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器;
更新迭代模块,用于基于所述更新的生成器,重新执行鉴别器训练模块中的内容,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。
第三方面,本发明实施例还提供一种智能终端,其中,所述智能终端包括存储器、处理器及存储在所述存储器中并可在所述处理器上运行的基于鉴别器森林提高生成对抗网络泛化能力程序,所述处理器执行所述基于鉴别器森林提高生成对抗网络泛化能力程序时,实现如以上任一项所述的基于鉴别器森林提高生成对抗网络泛化能力的方法的步骤。
第四方面,本发明实施例还提供一种计算机可读存储介质,其中,所述计算机可读存储介质上存储有基于鉴别器森林提高生成对抗网络泛化能力程序,所述基于鉴别器森林提高生成对抗网络泛化能力程序被处理器执行时,实现如以上任一项所述的基于鉴别器森林提高生成对抗网络泛化能力的方法的步骤。
有益效果:与现有技术相比,本发明提供了一种基于鉴别器森林提高生成对抗网络泛化能力的方法,本发明首先构建由一个生成器和鉴别器森林组成的生成对抗网络模型,再基于相互独立的训练样本对鉴别器分别进行训练,得到更新的鉴别器森林,然后,在生成器训练阶段,先基于生成样本对每个鉴别器进行训练得到每个鉴别器的损失值,再将所述鉴别器的损失值进行汇总得到的训练梯度回传给生成器,进而完成生成器的更新,这样有助于通过对生成器的训练,实现降低泛化误差上界,提高生成器的泛化能力,实现了通过鉴别器数量的增加以增强隐私保护能力的效果。最后,再基于更新的生成器,对鉴别器进行下一次迭代训练,以实现对生成对抗网络模型的优化。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明中记载的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的基于鉴别器森林提高生成对抗网络泛化能力的方法的流程图。
图2是本发明实施例提供的生成对抗网络模型的示意图。
图3是本发明实施例提供的构建训练数据集的方法示意图。
图4是本发明实施例提供的生成训练样本的方法示意图。
图5是本发明实施例提供的对鉴别器森林进行训练的方法示意图。
图6是本发明实施例提供的生成对抗网络模型的密度和分布的指标分析图。
图7是本发明实施例提供的生成对抗网络模型的准确率和召回率的指标分析图。
图8是本发明实施例提供基于条件的Forest-GAN的生成样本的指标分析图。
图9是本发明实施例提供的基于鉴别器森林提高生成对抗网络泛化能力的装置的原理框图。
图10是本发明实施例提供的智能终端的内部结构原理框图。
具体实施方式
为使本发明的目的、技术方案及效果更加清楚、明确,以下参照附图并举实施例对本发明进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
随着计算机生成图片技术取得空前发展,计算机已经能够自动生成高清晰度的数字图像,并且效果逼真难以通过肉眼区分。与传统基于计算机图形图像学的图像生成算法不同,采用生成对抗网络(GenerativeAdversarial Network,简称GAN)模型从大量图像样本中学习自然图像的模式,能够生成更加逼真的图像。
生成对抗网络(GAN)模型是一种深度学习模型,包含两个部分:一个是生成器(Generator),一个是判别器(Discriminator)。判别器的任务是判断生成模型生成的样本是真实的还是伪造的。生成器的任务是生成看起来逼真与原始数据相似的样本。也就是说,生成器要生成能骗过判别器的实例,而判别器要从真假混合的样本中揪出由生成器生成的伪造样本。生成器和判别器的训练过程是一个对抗博弈的过程,最后博弈的结果是在最理想的状态下,生成器可以生成足以“以假乱真”的样本。
然而,现有技术中通过最小化与目标分布的Jensen-Shannon散度,普通的GAN模型很难完全拟合复杂数据的真实分布。当使用普通的对抗损失函数来训练鉴别器时,性能强的鉴别器会导致模型坍塌,因为鉴别器会过度拟合训练数据,从而导致生成器走向坍塌均衡。容量小的鉴别器能够提高鉴别器的泛化能力,但是这会导致另一个问题:GAN模型中的生成器会通过记忆训练数据来得到低容量的鉴别器,而不是学习、逼近真实的数据分布,这会导致GAN模型容易受到攻击和隐私数据泄露等问题。虽然有很多相关的变体技术能够提高生成样本的质量,但是它们多数关心单个鉴别器GAN模型的泛化能力,并没有讨论多鉴别器GAN模型的泛化能力等问题。
因此,为了解决上述问题,本实施例提供一种基于鉴别器森林提高生成对抗网络泛化能力的方法,通过构建由一个生成器和鉴别器森林组成的生成对抗网络模型,实现了通过鉴别器数量的增加以增强隐私保护能力的效果,再基于相互独立的训练样本对鉴别器分别进行训练,然后,在生成器训练阶段,得到每个鉴别器的损失值,再将所述鉴别器的损失值进行汇总得到的训练梯度回传给生成器,进而完成生成器的更新,这样有助于通过对生成器的训练实现提高生成样本的多样性和质量并降低泛化误差。最后,再基于更新的生成器,对鉴别器进行下一次迭代训练,以实现对生成对抗网络模型的自我优化。本发明提出由鉴别器森林和一个生成器组成的生成对抗网络模型,能够通过提高生成样本的多样性和质量来降低泛化误差,并通过鉴别器数量的增加以增强隐私保护能力,通过生成器和判别器不断地进行对抗式训练,最终使得两个网络达到了一个动态平衡。
示例性方法
本实施例提供一种基于鉴别器森林提高生成对抗网络泛化能力的方法。如图1所示,所述方法包括如下步骤:
步骤S100、构建由一个生成器和鉴别器森林组成的生成对抗网络模型。
具体地,在生成对抗网络中,鉴别器的泛化能力决定了生成器的泛化能力。为了增强鉴别器的泛化能力,限制鉴别器的泛化误差上界,我们采用随机森林的方法来构建由多个鉴别器组成的鉴别器森林,即得到了Forest-GAN模型。在一种实施方式中,将鉴别器的个数定义为K,图2展示了生成对抗网络模型的框架结构。
举例说明,如图2所示,若鉴别器的个数为K,可将K个鉴别器分别标记为D1,D2,...,DK,。这样,由一个生成器和K个鉴别器就组成了Forest-GAN模型。
在一种实现方式中,本实施例所述步骤S100之前包括如下步骤:
步骤S10、构建所述生成器的神经网络结构,在卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活;
步骤S20、为所有鉴别器构建相同的神经网络结构,在每个鉴别器的卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活。
具体地,无论是生成器还是鉴别器,在卷积层与卷积层之间,都使用了批量归一化(BatchNormalization)和泄露的修正线性单元(LeakyRectified LinearUnit,LeakyReLU)。即利用批量归一化处理,提升神经网络模型对图像特征的辨识能力,进而可有效增强神经网络模型的泛化性能,并将上一层卷积输出的特征和残差块阶段内递归计算后输出的特征都由泄露的修正线性单元函数进行激活,并输入到下一层卷积。特别需要注意的是:本实施例中Forest-GAN模型的鉴别器森林中每个鉴别器的网络结构都是一样的。
步骤S200、生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林;
具体地,在训练GAN时,最终达到的理想状态是生成器与鉴别器之间达到平衡。这时,鉴别器无法区分真实数据与生成器生成的数据,因为生成器已经学会了生成看起来足以以假乱真的数据。在达到平衡的过程中,每一次迭代训练都需要计算鉴别器的损失值(loss)。可以通过原生GAN模型的损失函数、WGAN(WassersteinGAN)的损失函数/均方误差(mean squarederror,MSE)损失函数或二元交叉熵(binarycrossentropy,BCE)损失函数等方法计算所述鉴别器的损失值。本实施例通过为每一个鉴别器分别生成训练样本,即每一个鉴别器的训练样本都是相互独立的,再基于训练样本对鉴别器分别进行训练,以得到Forest-GAN模型中每个鉴别器各自的损失值。
在一种实现方式中,本实施例所述步骤S200包括如下步骤:
步骤S201、获取真实样本,并根据所述真实样本构建原始样本集,通过Bootstrap采样方法进行随机有放回地采样,得到与若干鉴别器一一对应的若干训练数据集;
步骤S202、将随机采样的128维的高斯噪声数据输入所述生成器,得到第二生成样本;
步骤S203、根据所述训练数据集中的真实样本和所述第二生成样本为每个鉴别器分别生成独立的训练样本;
步骤S204、基于所述训练样本,针对每一个鉴别器分别利用Adam优化器,采用交叉熵损失函数以最大化的方式进行训练,得到所述训练过的鉴别器森林。
具体地,在样本数量为m的真实样本构建原始样本集上,我们使用“Bootstrap”的采样方法在原始样本集上随机有放回地采样m个样本作为鉴别器森林中某个鉴别器的训练数据集,其他鉴别器的训练数据集也是使用同样的方法得到,如图3所示,鉴别器Di,i=1,...,K,所对应的训练数据集记为di,i=1,...,K。根据我们知道原始数据中任何一个元素在任意一个训练数据集上出现的概率是通过“Bootstrap”采样方法,我们可以得到K个相互独立的训练数据集,每个训练数据集对应鉴别器森林的一个鉴别器,这些训练数据集为鉴别器提供真实样本。
本实施例中,生成器会为每一个鉴别器提供生成样本,如图4所示,设定均值为0.0,方差为1.0的高斯分布,从该高斯分布上,将随机采样的128维的高斯噪声数据输入所述生成器,得到第二生成样本。将生成样本标记为0,将真实样本标记为1,将带标签的第二生成样本和真实样本构成训练样本共同作为鉴别器的输入。
在一种实现方式中,可将划分原始数据集为若干子集,在子集上进行“Bootstrap”采样以得到每个鉴别器训练数据集。
在一种实现方式中,若原始数据包含多个数据集,依次以相等或者不同的概率在各个数据集上使用“Bootstrap”采样,以得到每个鉴别器训练数据集。
具体地,本实施例使用K个互相独立的训练数据集对鉴别器森林的鉴别器进行独立训练,也会得到K个相互独立的鉴别器。在训练鉴别器阶段,本实施例将每个鉴别器的训练样本输入到对应的鉴别器以对鉴别器森林进行独立训练;采用交叉熵损失函数,以最大化的方式来训练鉴别器,让鉴别器能够以更大的概率正确区分输入数据的来源,并得到每一个鉴别器的损失值。
举例说明,如图5所示,基于训练数据集训练每个鉴别器Di,i=1,...,K,选择优化器是Adam,初始的学习率、一阶估计衰减率和二阶估计衰减率分别是:1e-4,0.5,0.999。在鉴别器的训练阶段,采用交叉熵损失函数,以最大化的方式来训练鉴别器,所述鉴别器将输入的训练样本进行鉴别,并标记鉴别结果为真实样本(Real)或生成样本(Fake)。将所述鉴别结果与生成器已标记0/1值进行比较,就得到了每一个鉴别器的损失值(loss),因鉴别器与鉴别器之间不共享损失值,不共享回传梯度,因此得到的鉴别器也会是相互独立。
在一种实现方式中,在对原始样本集进行采样以构建一个鉴别器的训练数据集时,还可采用概率采样,即通过遍历原始样本集中的每一个样本,以概率p接受该样本,以概率1-p拒绝该样本的方式进行采样。按照这样的概率采样方式,可以得到鉴别器训练数据集的大小约为:N·p,其中N是原始数据的大小。以相同的方法重复采样K次,就能够获得K个鉴别器所对应的训练数据集。
步骤S300、通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到所述鉴别器的损失值;
具体地,本实施例设定均值为0.0,方差为1.0的高斯分布,从该高斯分布上随机采样128维数据,用作生成器的输入,得到第一生成样本,并将得到的第一生成样本标记真实标签为0,输入到鉴别器森林中,得到预测结果。对于每个鉴别器,我们采用交叉熵损失函数来计算预测结果和真实标签之间的损失值,即得到所述鉴别器的损失值。
步骤S400、根据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器;
具体地,在训练生成器时,需要使用从鉴别器森林传递回来的训练梯度来训练生成器,而训练梯度是根据鉴别器的损失值得到的。将所述训练梯度传递给生成器,并以最小化的方式对所述生成器进行训练。
在一种实现方式中,本实施例所述步骤S400包括如下步骤:
步骤S401、将全部鉴别器的损失值进行均值汇总,得到均值汇总损失值;
步骤S402、根据所述均值汇总损失值计算所述训练梯度。
具体地,在训练生成器阶段,因为需要使用从鉴别器回传到生成器的训练梯度。如图5所示,本实施例使用“aggregation”的方式汇总鉴别器森林中K个鉴别器所产生的损失值,通过损失值计算训练梯度,进而完成对生成器的更新。本实施例采用均值汇总的方式得到鉴别器森林的均值汇总损失值loss,即:
再根据得到的均值汇总鉴别器森林中所产生的损失值计算训练梯度,以进行对生成器的训练。
举例说明,若Forest-GAN模型的鉴别器森林由3个鉴别器Di,i=3构成,通过对鉴别器分别进行训练,得到每个鉴别器的损失值为loss1=0.01,loss2=0.03,loss3=0.01,对3个鉴别器的损失值进行均值汇总,得到鉴别器森林的损失值为(0.01+0.03+0.01)/3=0.02。
在一种实现方式中,本实施例所述步骤S400包括如下步骤:
步骤M401、将全部鉴别器的损失值进行加权汇总,得到加权汇总损失值;
步骤M402、根据所述加权汇总损失值计算所述训练梯度。
具体地,在将全部鉴别器的损失值进行汇总时,对loss值的最大值、最小值、平均值、中位数等做加权和,权重的大小可以相等,也可以不相等,还可以伴随训练过程逐渐变化,以得到加权汇总损失值,再根据加权汇总损失值计算训练梯度。
在一种实现方式中,还可以采用最大/最小/中值化的方式将全部鉴别器的损失值进行汇总,即采用鉴别器森林中所得到的最大/最小/中位数数值loss值作为汇总损失值,再根据所述汇总损失值计算所述训练梯度。
步骤S500、基于所述更新的生成器,重新执行所述生成训练样本的步骤,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。
如图5所示,基于更新后的生成器,重新生成训练样本,并基于所述训练样本对每个鉴别器分别进行训练,得到每个鉴别器的损失值,以执行对鉴别器森林的下一次迭代训练。
本实施例中的Forest-GAN模型还能够和稳定训练的方法相结合,比如梯度惩罚(gradientpenalty)、批量归一化(BatchNormalization)、普归一化(Spectralnormalization)和R1正则化(R1regularization)等稳定训练过程、预防模型坍塌的训练方法结合,以得到更好的训练效果。
在一种实现方式中,本实施例还包括如下步骤:
步骤M10、设置所述生成器和所述鉴别器森林在一次迭代训练中的训练次数比例为1:1。
进一步地,对本实施例中的Forest-GAN模型提高泛化能力的效果进行验证。
其中,JSD是训练数据与生成数据之间的Jensen-shannon散度常量,所以生成数据所对应的分布需要不断近似等价于K个鉴别器训练数据集所表示的混合分布。
另外,因使用随机森林的方式来构建鉴别器的训练数据集,为此可以证明鉴别器森林的误差上界ΨD表示为:
式中,表示鉴别器森林中所有任意两个鉴别器之间关联系数的平均值,s表示鉴别器森林的鉴别能力。又因为生成器的误差上界不大于鉴别器森林的误差上界,为此我们确定鉴别器森林的泛化误差上界就是Forest-GAN模型的泛化误差上界。
在实验方面,本实施例通过三个学习任务来间接展示Forest-GAN模型的泛化能力:模型覆盖、密度估计和隐私保护。本实施例将在两类数据上进行实验:模拟数据和真实数据。本实施例通过数据分布的密度和准确率(Precision)与召回率(Recall)的定量计算来评估模型对真实数据的覆盖率。更高的准确率表示生成样本更加接近真实分布,更高的召回率表示更大的模型覆盖。
其中,对于模拟数据的构成,构建9个二维高斯分布,从这些二维高斯分布中均匀筛选10000个样本点;这些二维高斯分布的均值和方差依次满足:[-2,2],[0,2],[2,2],[-2,0],[0,0],[2,0],[-2,0],[0,-2]和[2,-2],方差依次对应于0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08和0.09;我们使用准确率(precision)和召回率(recall)来反映模型覆盖,可视化生成数据和真实数据的密度分布,然后分析Forest-GAN对真实数据的密度分布的拟合程度。在实验中,除了改变鉴别器的个数(K=1,2,5,10,20,50),本实施例还改变Forest-GAN模型的损失函数,使用两种常见的损失函数:原生GAN模型的损失函数和WGAN模型的损失函数。
图6为模拟数据和生成数据的分布和密度估计。当K=1,不管使用WGAN模型的损失函数还是使用原生GAN模型的损失函数,生成的结果都不能学习到真实分布中的9种模式;但是随着K逐渐增大,Forest-GAN逐渐能够学习到不同的模式,并且当K>20时,Forest-GAN更准确地学习到真实分布中的9种模式;在K=50时,表现出最好的性能,准确率和召回率分别达到0.8938和0.9064。相比于使用原生GAN模型的损失函数,使用WGAN损失函数的Forest-GAN具有更好的表现性能。在K≥2就能够较好地学习到真实分布的9种模态,在K=50时,表现出更好的性能,准确率和召回率分别达到0.9150和0.9231。模型可能会捕获到真实数据的所有模态,但不一定能够捕获到每个模态的密度分布。在图6的密度分布图中,随着鉴别器数量的不断提升,我们的模型能够更加准确地捕获到数据的分布情况。并且在K=50时,生成数据的密度分布能够最大程度上接近真实数据的密度分布。
图7展示了在训练过程中Forest-GAN在两种损失函数、不同鉴别器数量的设定下,生成样本和真实样本之间的准确率和召回率。可知,在使用原生GAN损失函数时,当K≤2,训练的迭代次数超过1000之后,召回率有下降的趋势,这意味着在鉴别器数量很少的情况下,Forest-GAN容易出现过拟合,生成模型会出现模型坍塌的趋势。当鉴别器的数量逐步提高,Forest-GAN能够避免过拟合,并引导生成器生成更多样化的样本。
此外,本实施例实现了基于条件的Forest-GAN(ConditionForest-GAN),在MNIST数据集上进行实验。
图8-A展示了训练过程中,真实样本与生成样本之间的准确率和召回率;召回率曲线随着训练过程先增加,再缓缓下降,但是随着鉴别器的数量增大,召回率在后期下降的趋势并不明显,这意味着,鉴别器数量的逐渐增加能够帮助模型缓解过拟合、缓解模型坍塌问题。图8-B展示了K=50时,基于条件的Forest-GAN按标签生成的样本经过t-SNE映射后得到的分布图,图8-C表示的是生成分布和真实分布在经过t-SNE映射后,在二维坐标上的密度分布图。图8-B和图8-C展示了生成样本十分接近真实样本。
最后,针对Forest-GAN,进行与隐私保护相关实验。在CIFAR-10数据集上评估Forest-GAN模型的隐私保护能力。在训练数据集和对应的非训练数据集上,通过泛化差距、和针对判别森林的白盒攻击来衡量泄露隐私的风险。局限于计算资源,只能尝试K=1,2,5这三种鉴别器设置。正如表1所示,在数据集CIFAR10上,Forest-GAN模型的平均泛化差距和受到白盒攻击的平均准确率。每一个值都是三次实验的平均结果;值越小,表示隐私保护能力越强。随着鉴别器数量的增加,白盒攻击的平均准确率逐渐下降,平均泛化差距也在逐步下降,这意味着:随着鉴别器数量的增加,Forest-GAN能够逐渐增强隐私保护能力。
表1.隐私保护实验数据表
鉴别器个数 | 平均差距 | 平均准确率 |
K=1 | 0.2104 | 0.6621 |
K=2 | 0.1840 | 0.6310 |
K=5 | 0.1154 | 0.5793 |
综上,本实施例使用边缘函数(marginfunction)来定义GAN模型的泛化误差度量,从理论上证明了Forest-GAN模型的泛化误差上界,分析了泛化误差边界与鉴别器的泛化能力和鉴别器之间的关联性逐渐的关联性。由于不能直接展示GAN模型的泛化能力,为此在密度估计、模型覆盖和个性化攻击这三个任务上间接展示Forest-GAN模型的泛化能力以及它的数据隐私保护能力。并且在虚拟数据和真实数据上,通过实验结果,进一步佐证Forest-GAN能够通过提高生成样本的多样性和质量来降低泛化误差,通过抵抗个性化攻击(membershipinferenceattack,MIA)来展示隐私保护能力。
本实施例证明Forest-GAN模型的泛化误差上界,并且论证这个误差上界是由独立鉴别器的泛化能力和鉴别器之间的相关性共同确定的;从理论上,为如何减低GANs的泛化误差上限提供了研究的理论基础。基于Forest-GAN模型的全局最优解近似“Bootstrap”训练数据集的混合分布这一结果表明:即使训练数据有限,通过记忆训练数据,Forest-GAN并不能达到全局最优解的。因此,当Forest-GAN达到全局最优时,生成器能够产生多样化的样本和抵御个性化攻击。本实施例证明了Forest-GAN模型的泛化误差是小于判别森林的泛化误差的。为此,我们可以通过提高鉴别器的泛化能力来提高Forest-GAN模型的泛化能力。Forest-GAN是灵活易变的,不仅可以与任意的损失函数、权重正则化方法相结合,还能够进行并行计算,这些特点使得Forest-GAN能够在适用于分布式机器学习和联邦学习。
示例性装置
在一种实现方式中,本实施例所述步骤S10包括如下步骤:
如图9中所示,本实施例还提供一种基于鉴别器森林提高生成对抗网络泛化能力的装置,所述装置包括:
生成对抗网络模型构建模块10,用于构建由一个生成器和鉴别器森林组成的生成对抗网络模型;
鉴别器训练模块20,用于生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林;
损失值获取模块30,用于通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到所述鉴别器的损失值;
生成器训练模块40,用于根据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器;
更新迭代模块50,用于基于所述更新的生成器,重新执行鉴别器训练模块中的内容,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。
在一种实现方式中,所述基于鉴别器森林提高生成对抗网络泛化能力的装置包括:
生成器构建单元,用于构建所述生成器的神经网络结构,在卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活;
鉴别器构建单元,用于为所有鉴别器构建相同的神经网络结构,在每个鉴别器的卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活。
在一种实现方式中,所述鉴别器训练模块20包括:
第一训练数据集获取单元,用于获取真实样本,并根据所述真实样本构建原始样本集,通过Bootstrap采样方法进行随机有放回地采样,得到与若干鉴别器一一对应的若干训练数据集;
第二生成样本获取单元,用于将随机采样的128维的高斯噪声数据输入所述生成器,得到第二生成样本;
第二训练样本获取单元,用于根据所述训练数据集中的真实样本和所述第二生成样本为每个鉴别器分别生成独立的训练样本;
鉴别器森林训练单元,用于基于所述训练样本,针对每一个鉴别器分别利用Adam优化器,采用交叉熵损失函数以最大化的方式进行训练,得到所述训练过的鉴别器森林。
在一种实现方式中,所述损失值获取模块30包括:
第一生成样本获取单元,用于设定高斯分布,从所述高斯分布上随机采样128维数据作为所述生成器的输入,得到所述第一生成样本;其中,所述第一生成样本带有生成标签;
损失值获取单元,用于将所述第一生成样本输入到所述每个鉴别器中得到预测结果,采用交叉熵损失函数得到所述预测结果和所述生成标签之间的损失值。
在一种实现方式中,所述生成器训练模块40包括:
均值汇总单元,用于将全部鉴别器的损失值进行均值汇总,得到均值汇总损失值;
加权汇总单元,用于将全部鉴别器的损失值进行加权汇总,得到加权汇总损失值;
训练梯度获取单元,用于根据所述均值汇总损失值计算所述训练梯度。
在一种实现方式中,所述基于鉴别器森林提高生成对抗网络泛化能力的装置还包括:
训练次数设置单元,用于设置所述生成器和所述鉴别器森林在一次迭代训练中的训练次数比例为1:1。
基于上述实施例,本发明还提供了一种智能终端,其原理框图可以如图10所示。该能终端包括通过系统总线连接的处理器、存储器、网络接口、显示屏、温度传感器。其中,该能终端的处理器用于提供计算和控制能力。该能终端的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和计算机程序。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该智能终端的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种基于鉴别器森林提高生成对抗网络泛化能力的方法。该智能终端的显示屏可以是液晶显示屏或者电子墨水显示屏,该智能终端的温度传感器是预先在能终端内部设置,用于检测内部设备的运行温度。
本领域技术人员可以理解,图10中示出的原理框图,仅仅是与本发明方案相关的部分结构的框图,并不构成对本发明方案所应用于其上的智能终端的限定,具体的智能终端以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本发明所提供的各实施例中所使用的对存储器、存储、运营数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双运营数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
综上,本发明公开了一种基于鉴别器森林提高生成对抗网络泛化能力的方法,首先构建由一个生成器和鉴别器森林组成的生成对抗网络模型;生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林,再据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器,最后基于所述更新的生成器,重新执行生成训练样本的步骤,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。本发明提出由鉴别器森林和一个生成器组成的生成对抗网络模型,通过鉴别器数量的增加,降低泛化误差上界,提高生成样本的质量和多样性,增强模型的隐私保护能力。并且我们在理论上给出Forest-GAN的泛化误差边界。我们使用边缘函数(margin function)来定义GAN的泛化误差度量,从理论上证明了Forest-GAN的泛化误差上界,分析了泛化误差边界与鉴别器的泛化能力和鉴别器之间的关联性逐渐的关联性。由于不能直接展示GAN的泛化能力,为此我们在密度估计、模型覆盖和个性化攻击这三个任务上间接展示Forest-GAN的泛化能力以及它的数据隐私保护能力。并且在虚拟数据和真实数据上,通过实验结果,进一步佐证Forest-GAN能够通过提高生成样本的多样性和质量来降低泛化误差,通过抵抗个性化攻击(membership inference attack,MIA)来展示隐私保护能力。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (10)
1.一种基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述方法包括:
构建由一个生成器和鉴别器森林组成的生成对抗网络模型;
生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林;
通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到所述鉴别器的损失值;
根据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器;
基于所述更新的生成器,重新执行所述生成训练样本的步骤,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。
2.根据权利要求1所述的基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述构建由一个生成器和鉴别器森林组成的生成对抗网络模型之前,包括:
构建所述生成器的神经网络结构,在卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活;
为所有鉴别器构建相同的神经网络结构,在每个鉴别器的卷积层和卷积层之间采用批量归一化进行归一化处理,并采用泄露的修正线性单元进行激活。
3.根据权利要求1所述的基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林,包括:
获取真实样本,并根据所述真实样本构建原始样本集,通过Bootstrap采样方法进行随机有放回地采样,得到与若干鉴别器一一对应的若干训练数据集;
将随机采样的128维的高斯噪声数据输入所述生成器,得到第二生成样本;
根据所述训练数据集中的真实样本和所述第二生成样本为每个鉴别器分别生成独立的训练样本;
基于所述训练样本,针对每一个鉴别器分别利用Adam优化器,采用交叉熵损失函数以最大化的方式进行训练,得到所述训练过的鉴别器森林。
4.根据权利要求1所述的基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到每个鉴别器的损失值,包括:
设定高斯分布,从所述高斯分布上随机采样128维数据作为所述生成器的输入,得到所述第一生成样本;其中,所述第一生成样本带有生成标签;
将所述第一生成样本输入到所述每个鉴别器,采用交叉熵损失函数得到所述鉴别器的损失值。
5.根据权利要求4所述的基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述根据所述鉴别器的损失值得到训练梯度,包括:
将全部鉴别器的损失值进行均值汇总,得到均值汇总损失值;
根据所述均值汇总损失值计算所述训练梯度。
6.根据权利要求4所述的基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述根据所述鉴别器的损失值得到训练梯度,包括:
将全部鉴别器的损失值进行加权汇总,得到加权汇总损失值;
根据所述加权汇总损失值计算所述训练梯度。
7.根据权利要求1所述的基于鉴别器森林提高生成对抗网络泛化能力的方法,其特征在于,所述方法还包括:
设置所述生成器和所述鉴别器森林在一次迭代训练中的训练次数比例为1:1。
8.一种基于鉴别器森林提高生成对抗网络泛化能力的装置,其特征在于,所述装置包括:
生成对抗网络模型构建模块,用于构建由一个生成器和鉴别器森林组成的生成对抗网络模型;
鉴别器训练模块,用于生成训练样本,基于所述训练样本对所述鉴别器森林中的每个鉴别器分别进行训练,得到训练过的鉴别器森林;
损失值获取模块,用于通过所述生成器得到第一生成样本,将所述第一生成样本输入到所述训练过的鉴别器森林中的每个鉴别器中,得到所述鉴别器的损失值;
生成器训练模块,用于根据所述鉴别器的损失值得到训练梯度,根据所述训练梯度更新所述生成器的训练参数,并以最小化的方式对所述生成器进行训练,得到更新的生成器;
更新迭代模块,用于基于所述更新的生成器,重新执行鉴别器训练模块中的内容,以对所述鉴别器进行下一次迭代训练,直到完成迭代次数为止。
9.一种智能终端,其特征在于,所述智能终端包括存储器、处理器及存储在所述存储器中并可在所述处理器上运行的基于鉴别器森林提高生成对抗网络泛化能力程序,所述处理器执行所述基于鉴别器森林提高生成对抗网络泛化能力程序时,实现如权利要求1-7任一项所述的基于鉴别器森林提高生成对抗网络泛化能力的方法的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有基于鉴别器森林提高生成对抗网络泛化能力程序,所述基于鉴别器森林提高生成对抗网络泛化能力程序被处理器执行时,实现如权利要求1-7任一项所述的基于鉴别器森林提高生成对抗网络泛化能力的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210994734.0A CN115329938A (zh) | 2022-08-18 | 2022-08-18 | 一种基于鉴别器森林提高生成对抗网络泛化能力的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210994734.0A CN115329938A (zh) | 2022-08-18 | 2022-08-18 | 一种基于鉴别器森林提高生成对抗网络泛化能力的方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115329938A true CN115329938A (zh) | 2022-11-11 |
Family
ID=83925449
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210994734.0A Pending CN115329938A (zh) | 2022-08-18 | 2022-08-18 | 一种基于鉴别器森林提高生成对抗网络泛化能力的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115329938A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116681790A (zh) * | 2023-07-18 | 2023-09-01 | 脉得智能科技(无锡)有限公司 | 一种超声造影图像生成模型的训练方法及图像的生成方法 |
-
2022
- 2022-08-18 CN CN202210994734.0A patent/CN115329938A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116681790A (zh) * | 2023-07-18 | 2023-09-01 | 脉得智能科技(无锡)有限公司 | 一种超声造影图像生成模型的训练方法及图像的生成方法 |
CN116681790B (zh) * | 2023-07-18 | 2024-03-22 | 脉得智能科技(无锡)有限公司 | 一种超声造影图像生成模型的训练方法及图像的生成方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11501192B2 (en) | Systems and methods for Bayesian optimization using non-linear mapping of input | |
CN110135510B (zh) | 一种动态领域自适应方法、设备及计算机可读存储介质 | |
CN115115905B (zh) | 基于生成模型的高可迁移性图像对抗样本生成方法 | |
CA2619973A1 (en) | Training convolutional neural networks on graphics processing units | |
CN112115967B (zh) | 一种基于数据保护的图像增量学习方法 | |
US9536206B2 (en) | Method and apparatus for improving resilience in customized program learning network computational environments | |
CN112434213B (zh) | 网络模型的训练方法、信息推送方法及相关装置 | |
CN112699941B (zh) | 植物病害严重程度图像分类方法、装置、设备和存储介质 | |
CN106886793B (zh) | 基于判别信息和流形信息的高光谱图像波段选择方法 | |
CN111488904A (zh) | 基于对抗分布训练的图像分类方法及系统 | |
CN114512191A (zh) | 一种基于迁移成分分析的青霉素浓度预测方法 | |
CN110991621A (zh) | 一种基于通道数搜索卷积神经网络的方法 | |
CN111695624A (zh) | 数据增强策略的更新方法、装置、设备及存储介质 | |
CN115329938A (zh) | 一种基于鉴别器森林提高生成对抗网络泛化能力的方法 | |
CN113935496A (zh) | 一种面向集成模型的鲁棒性提升防御方法 | |
CN114830137A (zh) | 用于生成预测模型的方法和系统 | |
CN113902959A (zh) | 图像识别方法、装置、计算机设备和存储介质 | |
Li | Sequential Design of Experiments to Estimate a Probability of Failure. | |
CN111967499A (zh) | 基于自步学习的数据降维方法 | |
CN111416595A (zh) | 一种基于多核融合的大数据滤波方法 | |
CN115083001B (zh) | 基于图像敏感位置定位的对抗补丁生成方法与装置 | |
CN117454668B (zh) | 零部件失效概率的预测方法、装置、设备和介质 | |
Zheng et al. | Meta Learning for Blind Image Quality Assessment Via Adaptive Sample Re-Weighting | |
Bhatia | Generalized Loss Functions for Generative Adversarial Networks | |
Kang et al. | Efficient Graduated Non-Convexity for Pose Graph Optimization |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |