CN113361566B - 用对抗性学习和判别性学习来迁移生成式对抗网络的方法 - Google Patents

用对抗性学习和判别性学习来迁移生成式对抗网络的方法 Download PDF

Info

Publication number
CN113361566B
CN113361566B CN202110534134.1A CN202110534134A CN113361566B CN 113361566 B CN113361566 B CN 113361566B CN 202110534134 A CN202110534134 A CN 202110534134A CN 113361566 B CN113361566 B CN 113361566B
Authority
CN
China
Prior art keywords
data
training
learning
domain
adt
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110534134.1A
Other languages
English (en)
Other versions
CN113361566A (zh
Inventor
李阳
王宇阳
文敦伟
常佳乐
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Changchun University of Technology
Original Assignee
Changchun University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Changchun University of Technology filed Critical Changchun University of Technology
Priority to CN202110534134.1A priority Critical patent/CN113361566B/zh
Publication of CN113361566A publication Critical patent/CN113361566A/zh
Application granted granted Critical
Publication of CN113361566B publication Critical patent/CN113361566B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

本发明公开了用对抗性学习和判别性学习来迁移生成式对抗网络的方法,它包括:S1.准备图片数据集;S2.构建预训练GAN模型;S3.通过参数迁移构建ADT‑GAN模型;S4.训练ADT‑GANc。上述ADT‑GAN模型利用迁移学习,在源域图像数据集训练的预训练GAN模型的基础上通过参数传递,初始化生成器和判别器。添加域判别器,通过优化由对抗目标函数和域判别目标函数组成的总目标函数,来驱动生成器生成目标域的图像数据,并避免负迁移。从而提高在小型目标域数据集上的训练性能,减少迭代次数,提高图像生成质量。

Description

用对抗性学习和判别性学习来迁移生成式对抗网络的方法
技术领域
本发明属于深度学习神经网络,具体涉及用对抗性学习和判别性学习来迁移生成式对抗网络的方法。
背景技术
生成式对抗网络(GANs是一种深层的模型,已经引起了广泛的关注,并且在许多领域中使用GAN的需求也在增长。像其他深度神经网络一样,GAN具有很高的计算需求,需要在大型数据集上进行训练,而在小型训练数据集上快速有效地训练 GAN,并生成有效的样本,这成为一个特别重要且具有挑战性的研究问题。
迁移学习它旨在通过转移包含在不同但相关的源域中的知识来提高目标域中目标学习者的表现。迁移学习已与GAN 一起使用,主要侧重于图像到图像的翻译和领域适应。图像到图像转换将图像从一个域转换到另一个域,并且输入和输出都是图像。域自适应的目的是将不同但相似的域(例如源域和目标域)的数据映射到相同的特征空间中,以提高目标域中分类模型的性能。以上两个类别利用了从源领域学到的知识或表示,并且对抗学习的机制可以集成到迁移学习方法中。但是,对抗学习的目的是训练网络,使其无法区分从源域提取的特征和从目标域提取的特征,这与GAN中的对抗学习的目的不 同。GAN 中对抗性学习的目的是训练网络,使其无法在同一域中将真实样本与生成的样本区分开。因此,以上两种方法均不支持在传递在源域中学习到的知识后通过在目标域中输入随机噪声来直接生成样本。
为了解决上述问题,本发明提出了一种新的 GAN 框架:用对抗性学习和判别性学习来迁移生成式对抗网络(ADT-GAN)。ADT-GAN可以支持将用源域的数据集训练的GAN 转移到相关的目标域,以对目标数据集进行进一步的训练,从而训练生成器以生成目标域的样本。首先,使用源域的训练数据集 对GAN模型进行预训练,并通过参数传递来初始 化ADT-GAN的生成器和判别器,然后对目标域的 数据集进行训练。同时,为了避免负迁移问题,采用了额外的域判别器来鼓励生成与目标域而不是源域具有相同分布的样本。
发明内容
本发明的目的在于解决生成对抗网络难以在小数据量的目标域图片数据集进行有效快速地训练的问题,而提出了用对抗性学习和判别性学习来迁移生成式对抗网络的方法。
用对抗性学习和判别性学习来迁移生成式对抗网络的方法,它包括:
S1. 准备图片数据集
1)将图片数据集分割为源域数据集和目标域数据集;
2)将数据集中的图片标准化到相同的分辨率;
S2. 构建预训练GAN模型
预训练GAN 模型为深度卷积生成对抗网络,它包括:生成器G以及判别器D;
S3. 通过参数迁移构建ADT-GAN模型
ADT-GAN模型,它包括生成器G ω ,判别器D θ ,以及域判别器C μ
G 初始化生成器 G 0G 0只用于生成数据,不参与迭代更新,其生成的数据
Figure 100002_DEST_PATH_IMAGE001
,其中
Figure 727989DEST_PATH_IMAGE002
G 0(z)的分布;
用G初始化生成器G ω : zx, 其中生成的数据是G ω(z) ∼ p G (x),ωG ω 的参数,p G (x) 是 G ω(z). 的分布;
用 D 初始化判别器D θ : x → [0, 1],D θ (x) 是 x 来自训练数据集的概率,θD θ 的参数;域判别器 C μ : x → [0,1],其中 C μ (x) 是 x 是来自
Figure 100002_DEST_PATH_IMAGE003
的数据的概率,其中μC μ 的参数;
S4.训练ADT-GAN模型;
步骤S4所述的训练ADT-GAN模型包括:
1) 定义对抗目标函数V adv
Figure 100002_DEST_PATH_IMAGE005
其含义为:输入噪声 z, G ω 可以生成数据 G ω (z)。D θ 用于区分 xG ω (z)。G ω 的目标是最小化对抗目标函数, 而 D θ 的目标是最大化对抗目标函数;
2) 定义域判别目标函数V adv
Figure 100002_DEST_PATH_IMAGE007
其含义为:输入噪声 z, G 0 可以生成数据 G 0 (z),G ω 可以生成数据 G ω (z)。 C μ 用于区分x,为给来自于目标域的G ω (z)和目标域训练集的x高分,给来自于源域的G 0(z)低分。G ω C μ 的目标均为最小化域判别目标函数;
3) 定义总目标函数V adv 为对抗目标函数V adv 和定义域判别目标函数V adv 的加权和,形式为
Figure 100002_DEST_PATH_IMAGE009
其中权重αβ为超参数;
4)训练ADT-GAN
ADT-GAN的训练在每个迭代周期分为四步,依次为:
(a)D θ 的学习;定义损失函数L 1
Figure 100002_DEST_PATH_IMAGE011
为使得D θ 能更好地区分真实数据x和生成数据G ω (z),固定G ω 以及C μ ,即固定参数μω 固定,学习D θ 的参数θ,这一步骤是通过对损失函数L 1 梯度上升来执行的:
(b)D θ 指导G ω 的学习,定义损失函数L 2
Figure DEST_PATH_IMAGE013
为了让G ω 生成的数据被D θ 被认为来自真实数据,在第一步更新的参数θ的基础上,通过对损失函数L 2 梯度下降来更新G ω 的参数ω
(c)C μ 的学习,定义损失函数L 3
Figure DEST_PATH_IMAGE015
为了让C μ 能更好地区分x,来自源域还是目标域,通过对损失函数L 3的梯度下降来更新C μ 的参数μ
(d)C μ 指导G ω 的学习;定义损失函数L 4
Figure DEST_PATH_IMAGE017
为了让G ω 可以生成更接近目标域而不是源域的图像,通过对损失函数L 4 梯度下降,来更新G ω 的参数ω
步骤S2所述的生成器G为: zx以噪声信号zp z (z)为输入,生成数据G(z) ∼p G (x),p G (x) 是生成数据的分布;
判别器D: x → [0, 1],给出真实数据的可能性;
给定的训练数据集中,p data(x)是真实数据x的分布,xp data(x);
定义训练目标函数V(G, D) = 𝔼 xp data(x)定义如下:
Figure DEST_PATH_IMAGE019
预训练GAN 模型中的参数,通过训练G最小化V得到,训练D最大化V得到;
所述的步骤S2,针对MNIST数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0005;
针对CelebA数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0001;
所述的步骤S4的4)针对MNIST数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0005;针对CelebA数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0001;
所述的用对抗性学习和判别性学习来迁移生成式对抗网络的方法,它还包括训练后模型的评估;
所述的模型的评估包括:通过其生成图像与真实图像在Inception v3 图像分类模型中得到网络的中间特征的Fréchet 初始距离FID来度量生成图像与真实图像的相似性;生成图像和真实图像在Inception v3 图像分类模型中得到网络的中间特征可建模成高斯分布,其均值分别为m rm g,协方差矩阵分别为Σr和Σg;描述两个中间特征统计相似性的FID定义为
Figure DEST_PATH_IMAGE021
本发明提供了用对抗性学习和判别性学习来迁移生成式对抗网络的方法,它包括:S1.准备图片数据集;S2.构建预训练GAN模型;S3.通过参数迁移构建ADT-GAN模型;S4.训练ADT-GANc。上述ADT-GAN模型利用迁移学习,在源域图像数据集训练的预训练GAN模型的基础上通过参数传递,初始化生成器和判别器。添加域判别器,通过优化由对抗目标函数和域判别目标函数组成的总目标函数,来驱动生成器生成目标域的图像数据,并避免负迁移。从而提高在小型目标域数据集上的训练性能,减少迭代次数,提高图像生成质量。
附图说明
图1 ADT-GAN模型的结构示意图;
图2 在MNIST上训练DCGAN、初始化DCGAN和ADT-GAN不同迭代中的FID值;
图3 在CelebA上训练DCGAN、初始化DCGAN和ADT-GAN不同迭代中的FID值;
图4 在MNIST-9上由DCGAN、初始化DCGAN和ADT-GAN第800次迭代生成的图像;
图5 在CelebA-M-1.10上由DCGAN、初始化DCGAN和ADT-GAN第2400次迭代生成的图像;
图6在MNIST-9上由初始化DCGAN和ADT-GAN第2000次迭代生成的图像;
图7在MNIST-9-1.2上由初始化DCGAN和ADT-GAN第2400次迭代生成的图像;
图8在MNIST-9-1.5上由初始化DCGAN和ADT-GAN第2000次迭代生成的图像。
具体实施方式
实施例1
一种基于对抗性学习和判别行学习相结合的对抗神经网路模型ADT-GAN及训练方法,它包括:
S1. 准备图片数据集
准备含有较大数据量的源域图片数据集以及较小数据量的目标域图片数据集,并将源域数据集和目标域数据集分别做以下处理:
1)将图片数据集分割为源域数据集和目标域数据集;
2)将数据集中的图片标准化到相同的分辨率;
MNIST是手写数字的数据集,由60000个训练数据和10000个图像的测试数据组成,本发明只采用训练数据。对MNIST手写数据数据集,每个手写数字图像尺寸标准化至 28×28 像素的灰度图像,并置于图像中心。为了检验ADT-GAN 的效果和训练GAN对小训练数据集的影响,我们将MNIST的训练集分为两类,一类是含有数字9的图像的MNIST-9训练集,另一类是不含有数字9的图像的 MNIST-not9训练集。MNIST-not9 训练集作为源域的训练集,MNIST-9 训练集作为目标域的训练集。根据数字9图像的个数,构造MNIST-9-1.2和MNIST-9-1.5分别为MNIST-9的1/2和1/5,来评估小数据量的目标域训练集上的训练性能。训练数据集的大小分配见表1
Figure DEST_PATH_IMAGE023
CelebA 是一个名人面部图像的数据集。它包含 10177 个名人身份的 202599 张人脸图像,包括118165 张女性人脸图像和84434 张男性人脸图像。对CelebA数据集,每个人脸图像尺寸标准化至 64×64 像素。然后,将女性图像放入训练集 CelebA-F,作为本实验源域的训练集,将男性图像放入训练集 CelebA-M,作为目标域训练集,再根据男性图像数分别为 CelebA-M 的 1/10 和 1/50 构造 CelebA-M-1.10 和 CelebA-M-1.50,来评估小数据量的目标域训练集上的训练性能。有关每个训练集的具体大小见表2。
Figure DEST_PATH_IMAGE025
S2. 构建预训练GAN模型
预训练GAN 模型为深度卷积生成对抗网路(Deep Convolutional GenerativeAdversarial Network, DCGAN),包括生成器G以及判别器D。
给定的训练数据集中,p data(x)是真实数据x的分布,xp data(x)。生成器G: zx以噪声信号zp z (z)为输入,生成数据G(z) ∼ p G (x),p G (x) 是生成数据的分布。判别器D: x → [0, 1],给出真实数据的可能性。判别器接收来自两个来源的数据作为输入:来自训练数据集的真实数据和来自生成器的生成数据。针对MNIST数据集的判别器及生成器的网络结构及超参数见表3。
Figure DEST_PATH_IMAGE027
针对CelebA数据集的判别器及生成器的网络结构及超参数见表4。
Figure DEST_PATH_IMAGE029
定义训练目标函数V(G, D) = 𝔼 xp data(x)定义如下:
Figure DEST_PATH_IMAGE031
预训练GAN 模型中的参数,通过训练G最小化V得到,训练D最大化V得到。
针对MNIST数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0005。
针对CelebA数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0001。
S3. 通过参数迁移构建ADT-GAN模型
G 初始化生成器 G 0G 0只用于生成数据,不参与迭代更新,其生成的数据
Figure 71508DEST_PATH_IMAGE001
,其中
Figure 935559DEST_PATH_IMAGE002
G 0(z)的分布。
用G初始化生成器G ω : zx, 其中生成的数据是G ω(z) ∼ p G (x),ωG ω 的参数,p G (x) 是 G ω(z). 的分布。
用 D 初始化判别器D θ : x → [0, 1],D θ (x) 是 x 来自训练数据集的概率,θD θ 的参数。域判别器 C μ : x → [0,1],其中 C μ (x) 是 x 是来自
Figure 218772DEST_PATH_IMAGE003
的数据的概率,其中μC μ 的参数。
ADT-GAN模型,如图1所示,它包括生成器G ω ,判别器D θ ,以及域判别器C μ 。对于MNIST数据集的生成器G ω ,判别器D θ ,以及域判别器C μ 的网络结构及超参数见表5。对于MNIST数据集的生成器G ω ,判别器D θ ,以及域判别器C μ 的网络结构及超参数见表6。
Figure DEST_PATH_IMAGE033
Figure DEST_PATH_IMAGE035
S4.训练ADT-GAN模型
1) 定义对抗目标函数V adv
Figure DEST_PATH_IMAGE037
其含义为:输入噪声 z, G ω 可以生成数据 G ω (z)。D θ 用于区分 xG ω (z)。G ω 的目标是最小化对抗目标函数, 而 D θ 的目标是最大化对抗目标函数。
2) 定义域判别目标函数V adv
Figure DEST_PATH_IMAGE039
其含义为:输入噪声 z, G 0 可以生成数据 G 0 (z),G ω 可以生成数据 G ω (z)。 C μ 用于区分x,为给来自于目标域的G ω (z)和目标域训练集的x高分,给来自于源域的G 0(z)低分。G ω C μ 的目标均为最小化域判别目标函数。
3) 定义总目标函数V adv 为对抗目标函数V adv 和定义域判别目标函数V adv 的加权和,形式为
Figure DEST_PATH_IMAGE041
其中权重αβ为超参数(hyperpaarmeter),控制域判别目标函数相对于对抗目标函数的重要性。对于MNIST数据集为α = 1.0,β= 2.0。对于CelebA数据集,α = 1.0,β= 0.2。
4)训练ADT-GAN
ADT-GAN的训练在每个迭代周期分为四步,依次为:
(a)D θ 的学习。定义损失函数L 1
Figure DEST_PATH_IMAGE043
为使得D θ 能更好地区分真实数据x和生成数据G ω (z),固定G ω 以及C μ ,即固定参数μω 固定,学习D θ 的参数θ。这一步骤是通过对损失函数L 1 梯度上升来执行的:
(b)D θ 指导G ω 的学习。定义损失函数L 2
Figure DEST_PATH_IMAGE045
为了让G ω 生成的数据被D θ 被认为来自真实数据,在第一步更新的参数θ的基础上,通过对损失函数L 2 梯度下降来更新G ω 的参数ω
(c)C μ 的学习。定义损失函数L 3
Figure DEST_PATH_IMAGE047
为了让C μ 能更好地区分x,来自源域还是目标域,通过对损失函数L 3的梯度下降来更新C μ 的参数μ
(d)C μ 指导G ω 的学习。定义损失函数L 4
Figure DEST_PATH_IMAGE049
为了让G ω 可以生成更接近目标域而不是源域的图像,通过对损失函数L 4 梯度下降,来更新G ω 的参数ω
以上四步每步均采用如下优化算法及优化参数:
针对MNIST数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0005。
针对CelebA数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0001。
实施例2 训练后模型的评估
GAN(包括ADT-GAN)的评估,可以通过其生成图像与真实图像在Inception v3 图像分类模型中得到网络的中间特征的Fréchet 初始距离(FID)来度量生成图像与真实图像的相似性。生成图像和真实图像在Inception v3 图像分类模型中得到网络的中间特征可建模成高斯分布,其均值分别为m rm g,协方差矩阵分别为Σr和Σg。描述两个中间特征统计相似性的FID定义为
Figure DEST_PATH_IMAGE051
FID越小,代表两组图像越相似;FID越大,代表两组图像差别越大。
利用DCGAN、初始化DCGAN和ADTGAN在同一次迭代中的FID值来判断初始化DCGAN和ADT-GAN中是否存在负迁移。更具体地说,如果在同一迭代中初始化的DCGAN或ADTGAN的FID值高于DCGAN的FID值,则在初始化的DCGAN或ADT-GAN中存在负转移, 验证ADT-GAN中的域分类是否有效。
在MNIST和CelebA上DCGAN、初始化DCGAN和ADT-GAN的FID值,如图2和图3所示。在图 2(a)中,ADT-GAN和初始DCGAN分别需要800次迭代和1000次迭代,将FID值降低到36以下,而 DCGAN在2200 次迭代中最小FID值为43.28。另外,在图2(b)中,ADT-GAN和初始化的DCGAN需要600次迭代和800次迭代,将FID值降低到48以下,而 DCGAN在2400次迭代中最小FID值为56.98。同样在图2(c)中,ADT-GAN和初始化的DCGAN需要1000次迭代和400次迭代,将FID值降低到80以下,而 DCGAN在1800次迭代中最小FID值为83.96。MNIST上由DCGAN、初始化DCGAN和ADT-GAN生成的图像如图4所示。通过对图3的观察,可以发现在CelebA上的实验结果与在MNIST上的实验结果是一致的。ADT-GAN和初始化的DCGAN都可以在较少的迭代次数内将FID值降低到低于DCGAN的最小FID值。CelebA上由DCGAN、初始化DCGAN和ADT-GAN生成的图像如图5所示。总之,在MNIST和CelebA上,ADT-GAN和初始化的DCGAN可以在较少的迭代次数内获得比DCGAN更好地性能,且ADT-GAN的性能还要好于初始化的DCGAN。
进一步观察图2和图3,分析ADT-GAN和初始化的DCGAN在MNIST和CelebA上是否有负迁移。在图2中,初始化的DCGAN在MNIST上具有比相同迭代次数下的DCGAN高的FID值,这意味着初始化的DCGAN在MNIST上具有负迁移。在图2(a)中,初始化的DCGAN具有从2000到2600次迭代的负迁移。同样的效果在图2(b)和图2(c)中也可以看到,分别经过2000次迭代和1400到4800次迭代。与此相反,ADT-GAN没有负迁移的迹象。通过图 6、图 7、图 8,可以发现初始化后的DCGAN产生了许多源域图像(画框),如数字 0、1、7、6、8 等,而ADT-GAN没有,这验证ADT-GAN的域鉴别器的有效性。在图3中,相同迭代次数下ADT-GAN和初始化 DCGAN的FID值低于DCGAN,说明 ADT-GAN和初始化DCGAN在CelebA上没有负迁移。可见ADT-GAN无论在MNIST上还是CelebA上都不存在负迁移。

Claims (5)

1. 用对抗性学习和判别性学习来迁移生成式对抗网络的方法,它包括:
S1. 准备图片数据集
1)将图片数据集分割为源域数据集和目标域数据集;
2)将数据集中的图片标准化到相同的分辨率;
S2. 构建预训练GAN模型
预训练GAN 模型为深度卷积生成对抗网络,它包括:生成器G以及判别器D;
S3. 通过参数迁移构建ADT-GAN模型
ADT-GAN模型,它包括生成器G ω ,判别器D θ ,以及域判别器C μ
G 初始化生成器 G 0G 0只用于生成数据,不参与迭代更新,其生成的数据
Figure DEST_PATH_IMAGE001
,其中
Figure DEST_PATH_IMAGE002
G 0(z)的分布;
用G初始化生成器G ω : zx, 其中生成的图像数据是G ω(z) ∼ p G (x),ωG ω 的参数,p G (x) 是 G ω(z) 的分布;
用 D 初始化判别器D θ : x → [0, 1],D θ (x) 是 x 来自训练数据集的概率,θD θ 的参数;域判别器 C μ : x → [0,1],其中 C μ (x) 是 x 是来自
Figure DEST_PATH_IMAGE003
的数据的概率,其中μC μ 的参数;
S4.训练ADT-GAN模型。
步骤S4所述的训练ADT-GAN模型包括:
1) 定义对抗目标函数V adv
Figure DEST_PATH_IMAGE004
其含义为:输入噪声 z, G ω 可以生成数据 G ω (z),D θ 用于区分 xG ω (z);G ω 的目标是最小化对抗目标函数, 而 D θ 的目标是最大化对抗目标函数;
2) 定义域判别目标函数V adv
Figure DEST_PATH_IMAGE005
其含义为:输入噪声 z, G 0 可以生成数据 G 0 (z),G ω 可以生成数据 G ω (z), C μ 用于区分x,为给来自于目标域的G ω (z)和目标域训练集的x高分,给来自于源域的G 0(z)低分,G ω C μ 的目标均为最小化域判别目标函数;
3) 定义总目标函数V adv 为对抗目标函数V adv 和定义域判别目标函数V adv 的加权和,形式为
Figure DEST_PATH_IMAGE006
其中权重αβ为超参数;
4)训练ADT-GAN
ADT-GAN的训练在每个迭代周期分为四步,依次为:
(a)D θ 的学习,定义损失函数L 1
Figure DEST_PATH_IMAGE007
为使得D θ 能更好地区分真实图像数据x和生成图像数据G ω (z),固定G ω 以及C μ ,即固定参数μω 固定,学习D θ 的参数θ,这一步骤是通过对损失函数L 1 梯度上升来执行的:
(b)D θ 指导G ω 的学习,定义损失函数L 2
Figure DEST_PATH_IMAGE008
为了让G ω 生成的数据被D θ 被认为来自真实数据,在第一步更新的参数θ的基础上,通过对损失函数L 2 梯度下降来更新G ω 的参数ω
(c)C μ 的学习,定义损失函数L 3
Figure DEST_PATH_IMAGE009
为了让C μ 能更好地区分x,来自源域还是目标域,通过对损失函数L 3的梯度下降来更新C μ 的参数μ
(d)C μ 指导G ω 的学习;定义损失函数L 4
Figure DEST_PATH_IMAGE010
为了让G ω 可以生成更接近目标域而不是源域的图像,通过对损失函数L 4 梯度下降,来更新G ω 的参数ω
2.根据权利要求1所述的用对抗性学习和判别性学习来迁移生成式对抗网络的方法,其特征在于:步骤S2所述的生成器G为: zx以噪声信号zp z (z)为输入,生成数据G(z) ∼ p G (x),p G (x) 是生成数据的分布;
判别器D: x → [0, 1],给出真实数据的可能性;
给定的训练数据集中,p data(x)是真实数据x的分布,xp data(x);
定义训练目标函数V(G, D) = 𝔼 xp data(x)定义如下:
Figure DEST_PATH_IMAGE011
预训练GAN 模型中的参数,通过训练G最小化V得到,训练D最大化V得到。
3.根据权利要求2所述的用对抗性学习和判别性学习来迁移生成式对抗网络的方法,其特征在于:所述的步骤S2,针对MNIST数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0005;
针对CelebA数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0001。
4.根据权利要求3所述的用对抗性学习和判别性学习来迁移生成式对抗网络的方法,其特征在于:所述的步骤S4的4)针对MNIST数据集,优化算法为Adam,优化算法参数β1 =0.5,学习速率为0.0005;针对CelebA数据集,优化算法为Adam,优化算法参数β1 = 0.5,学习速率为0.0001。
5.根据权利要求1、2、3或4 所述的用对抗性学习和判别性学习来迁移生成式对抗网络的方法,其特征在于:它还包括训练后模型的评估;
所述的模型的评估包括:通过其生成图像与真实图像在Inception v3 图像分类模型中得到网络的中间特征的Fréchet 初始距离FID来度量生成图像与真实图像的相似性;生成图像和真实图像在Inception v3 图像分类模型中得到网络的中间特征可建模成高斯分布,其均值分别为m rm g,协方差矩阵分别为Σr和Σg;描述两个中间特征统计相似性的FID定义为
Figure DEST_PATH_IMAGE012
CN202110534134.1A 2021-05-17 2021-05-17 用对抗性学习和判别性学习来迁移生成式对抗网络的方法 Active CN113361566B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110534134.1A CN113361566B (zh) 2021-05-17 2021-05-17 用对抗性学习和判别性学习来迁移生成式对抗网络的方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110534134.1A CN113361566B (zh) 2021-05-17 2021-05-17 用对抗性学习和判别性学习来迁移生成式对抗网络的方法

Publications (2)

Publication Number Publication Date
CN113361566A CN113361566A (zh) 2021-09-07
CN113361566B true CN113361566B (zh) 2022-11-15

Family

ID=77526781

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110534134.1A Active CN113361566B (zh) 2021-05-17 2021-05-17 用对抗性学习和判别性学习来迁移生成式对抗网络的方法

Country Status (1)

Country Link
CN (1) CN113361566B (zh)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114399829B (zh) * 2022-03-25 2022-07-05 浙江壹体科技有限公司 基于生成式对抗网络的姿态迁移方法、电子设备及介质
CN115272687B (zh) * 2022-07-11 2023-05-05 哈尔滨工业大学 单样本自适应域生成器迁移方法
CN115936090A (zh) * 2022-11-25 2023-04-07 北京百度网讯科技有限公司 模型训练方法、设备和存储介质
CN116128047B (zh) * 2022-12-08 2023-11-14 西南民族大学 一种基于对抗网络的迁移学习方法
CN116187206B (zh) * 2023-04-25 2023-07-07 山东省科学院海洋仪器仪表研究所 一种基于生成对抗网络的cod光谱数据迁移方法
CN116737793A (zh) * 2023-05-29 2023-09-12 南方电网能源发展研究院有限责任公司 碳排放流生成方法、模型训练方法、装置和计算机设备

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109753992A (zh) * 2018-12-10 2019-05-14 南京师范大学 基于条件生成对抗网络的无监督域适应图像分类方法
CN111028146A (zh) * 2019-11-06 2020-04-17 武汉理工大学 基于双判别器的生成对抗网络的图像超分辨率方法
CN111242157A (zh) * 2019-11-22 2020-06-05 北京理工大学 联合深度注意力特征和条件对抗的无监督域自适应方法
CN111738315A (zh) * 2020-06-10 2020-10-02 西安电子科技大学 基于对抗融合多源迁移学习的图像分类方法
CN112801895A (zh) * 2021-01-15 2021-05-14 南京邮电大学 一种基于二阶段注意力机制gan网络图像修复算法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11105942B2 (en) * 2018-03-27 2021-08-31 Schlumberger Technology Corporation Generative adversarial network seismic data processor

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109753992A (zh) * 2018-12-10 2019-05-14 南京师范大学 基于条件生成对抗网络的无监督域适应图像分类方法
CN111028146A (zh) * 2019-11-06 2020-04-17 武汉理工大学 基于双判别器的生成对抗网络的图像超分辨率方法
CN111242157A (zh) * 2019-11-22 2020-06-05 北京理工大学 联合深度注意力特征和条件对抗的无监督域自适应方法
CN111738315A (zh) * 2020-06-10 2020-10-02 西安电子科技大学 基于对抗融合多源迁移学习的图像分类方法
CN112801895A (zh) * 2021-01-15 2021-05-14 南京邮电大学 一种基于二阶段注意力机制gan网络图像修复算法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
一种基于GAN和自适应迁移学习的样本生成方法;周立君等;《应用光学》;20200115(第01期);第120-126页 *
基于改进的CycleGAN模型非配对的图像到图像转换;何剑华等;《玉林师范学院学报》;20180401(第02期);第122-126页 *
生成式对抗网络GAN的研究进展与展望;王坤峰等;《自动化学报》;20170315(第03期);第321-332页 *

Also Published As

Publication number Publication date
CN113361566A (zh) 2021-09-07

Similar Documents

Publication Publication Date Title
CN113361566B (zh) 用对抗性学习和判别性学习来迁移生成式对抗网络的方法
CN110414377B (zh) 一种基于尺度注意力网络的遥感图像场景分类方法
CN112446423B (zh) 一种基于迁移学习的快速混合高阶注意力域对抗网络的方法
CN109345508B (zh) 一种基于两阶段神经网络的骨龄评价方法
CN108648191B (zh) 基于贝叶斯宽度残差神经网络的害虫图像识别方法
CN114841257B (zh) 一种基于自监督对比约束下的小样本目标检测方法
CN114283287B (zh) 基于自训练噪声标签纠正的鲁棒领域自适应图像学习方法
CN112115967B (zh) 一种基于数据保护的图像增量学习方法
CN111401156B (zh) 基于Gabor卷积神经网络的图像识别方法
Zhou et al. When semi-supervised learning meets transfer learning: Training strategies, models and datasets
CN114038055A (zh) 一种基于对比学习和生成对抗网络的图像生成方法
CN112364791A (zh) 一种基于生成对抗网络的行人重识别方法和系统
CN114004333A (zh) 一种基于多假类生成对抗网络的过采样方法
CN113256508A (zh) 一种改进的小波变换与卷积神经网络图像去噪声的方法
CN114863348A (zh) 基于自监督的视频目标分割方法
CN114722892A (zh) 基于机器学习的持续学习方法及装置
CN112614070A (zh) 一种基于DefogNet的单幅图像去雾方法
Teng et al. BiSeNet-oriented context attention model for image semantic segmentation
Li et al. Distilling ensemble of explanations for weakly-supervised pre-training of image segmentation models
Quiroga et al. A study of convolutional architectures for handshape recognition applied to sign language
Zhang et al. Long-tailed classification with gradual balanced loss and adaptive feature generation
CN108428226B (zh) 一种基于ica稀疏表示与som的失真图像质量评价方法
Linneberg et al. Towards semen quality assessment using neural networks
Yow et al. Iris recognition system (IRS) using deep learning technique
CN114783039A (zh) 一种3d人体模型驱动的运动迁移方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant