CN111062468A - 生成网络的训练方法和系统、以及图像生成方法及设备 - Google Patents

生成网络的训练方法和系统、以及图像生成方法及设备 Download PDF

Info

Publication number
CN111062468A
CN111062468A CN202010152216.5A CN202010152216A CN111062468A CN 111062468 A CN111062468 A CN 111062468A CN 202010152216 A CN202010152216 A CN 202010152216A CN 111062468 A CN111062468 A CN 111062468A
Authority
CN
China
Prior art keywords
network
parameter
generating
sample
generated
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010152216.5A
Other languages
English (en)
Other versions
CN111062468B (zh
Inventor
李建
郭建波
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tuling Artificial Intelligence Institute Nanjing Co ltd
Tsinghua University
Original Assignee
Tuling Artificial Intelligence Institute Nanjing Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tuling Artificial Intelligence Institute Nanjing Co ltd filed Critical Tuling Artificial Intelligence Institute Nanjing Co ltd
Priority to CN202010152216.5A priority Critical patent/CN111062468B/zh
Publication of CN111062468A publication Critical patent/CN111062468A/zh
Application granted granted Critical
Publication of CN111062468B publication Critical patent/CN111062468B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y04INFORMATION OR COMMUNICATION TECHNOLOGIES HAVING AN IMPACT ON OTHER TECHNOLOGY AREAS
    • Y04SSYSTEMS INTEGRATING TECHNOLOGIES RELATED TO POWER NETWORK OPERATION, COMMUNICATION OR INFORMATION TECHNOLOGIES FOR IMPROVING THE ELECTRICAL POWER GENERATION, TRANSMISSION, DISTRIBUTION, MANAGEMENT OR USAGE, i.e. SMART GRIDS
    • Y04S10/00Systems supporting electrical power generation, transmission or distribution
    • Y04S10/50Systems or methods supporting the power network operation or management, involving a certain degree of interaction with the load-side end user applications

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请公开一种生成网络的训练方法和系统、以及图像生成方法及设备。其中,生成网络的训练方法包括以下步骤:获取真实样本和生成样本,所述生成样本是由使用第一参数的生成网络所生成的;利用所述真实样本和所述生成样本对映射网络的第二参数进行训练,直至所训练的映射网络使得所述真实样本和所述生成样本之间的特征距离符合最大化条件;利用训练后的映射网络对所述生成网络的第一参数进行训练,直至训练后的第二参数的条件下的所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合最小化条件。本申请在大大降低了计算量和节约内存的同时提高了生成网络对高维数据的处理能力以及生成高维图像的性能。

Description

生成网络的训练方法和系统、以及图像生成方法及设备
技术领域
本申请涉及图像处理技术领域,尤其涉及一种生成网络的训练方法和系统、以及图像生成方法及设备。
背景技术
生成网络作为一种人工神经网络,可以根据具体任务通过模型训练而实现将输入的数据样本生成文字、图像、视频等。因此,生成网络被广泛的应用于图像的生成、修复、转换以及重构,文字与图片之间的转换,视频预测等多种场合。
由于近些年来人工智能技术的愈发成熟,各种生成网络的学习/训练方法也不断涌现,这也使得训练出的生成网络在生成图像的质量方面参差不齐。例如,采用有监督式学习来训练生成网络,但这种方法的主要瓶颈之一是难以获得足够的标签样本来学习数据特征以使得生成网络能够捕获数据结构;再如,无监督式学习可以通过自己从错误中进行学习,并降低未来出错的概率以此达到降低对标签样本的需求,但是,这种学习方式往往准确率很低。
是故,如何去训练生成网络以使得生成网络能够高效的输出以假乱真的图像是本领域从业者亟待解决的问题。
发明内容
鉴于以上所述相关技术的缺点,本申请的目的在于提供一种生成网络的训练方法和系统、以及图像生成方法及设备。
为实现上述目的及其他相关目的,本申请第一方面公开一种生成网络的训练方法,其特征在于,包括以下步骤:获取真实样本和生成样本,所述生成样本是由使用第一参数的生成网络所生成的;利用所述真实样本和所述生成样本对映射网络的第二参数进行训练,直至所训练的映射网络使得所述真实样本和所述生成样本之间的特征距离符合最大化条件;利用训练后的映射网络对所述生成网络的第一参数进行训练,直至训练后的第二参数的条件下的所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合最小化条件;重复上述过程直至经训练得到的第一参数满足预设收敛条件。
在本申请第一方面的某些实施例中,所述映射网络用于将所述真实样本和生成样本映射到特征表示空间内。
在本申请第一方面的某些实施例中,所述映射网络的映射包括:降维操作。
在本申请第一方面的某些实施例中,所述特征距离为经所述映射网络映射到特征表示空间内的真实样本的特征和生成样本的特征之间的最大均值差异。
在本申请第一方面的某些实施例中,所述利用所述真实样本和所述生成样本对映射网络的第二参数进行训练包括:计算经所述映射网络得到的真实样本和生成样本之间的特征距离;基于所述特征距离确定所述第二参数的偏移量,以用于更新所述第二参数;利用更新后的第二参数重复上述步骤,直至所得到的所述特征距离符合所述最大化条件。
在本申请第一方面的某些实施例中,所述第二参数的偏移量包括:基于所述特征距离在第二参数上的梯度而得到的偏移量。
在本申请第一方面的某些实施例中,所述利用训练后的映射网络对所述生成网络的第一参数进行训练的步骤包括:在训练后的第二参数的条件下,基于所述真实样本和每次更新第一参数后生成网络输出的生成样本之间的特征距离而确定的所述第一参数的偏移量,更新所述生成网络的第一参数,直至所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合所述最小化条件。
在本申请第一方面的某些实施例中,所述第一参数的的偏移量包括:每次更新第一参数后生成网络输出的生成样本和所述真实样本之间的特征距离在第一参数上的梯度而得到的偏移量。
在本申请第一方面的某些实施例中,所述训练方法还包括获取一随机变量的步骤,所述生成样本即为所述生成网络基于所述随机变量生成的。
在本申请第一方面的某些实施例中,所述随机变量遵循于一正态分布或一均匀分布。
在本申请第一方面的某些实施例中,所述真实样本为参考图像,所述生成样本为生成网络输出的生成图像。
本申请的第二方面公开一种生成网络的训练系统,其特征在于,包括:映射网络模块,具有第二参数,用于获取真实样本和生成样本,并基于第二参数输出映射后的真实样本和生成样本;生成网络模块,具有第一参数,用于基于第一参数输出生成样本;训练模块,用于基于第一参数对所述映射网络模块的第二参数训练,直至所训练的映射网络模块使得所述真实样本和使用第一参数的生成网络输出的生成样本之间的特征距离符合最大化条件,以及用于基于训练后的映射网络模块对所述生成网络模块的第一参数训练,直至训练后的第二参数的条件下的所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合最小化条件;所述训练模块还用于重复上述训练过程,直至经训练的第一参数满足预设收敛条件。
在本申请第二方面的某些实施例中,所述映射网络模块包括:特征表示单元,用于将所述真实样本和生成样本映射到特征表示空间内。
在本申请第二方面的某些实施例中,所述特征表示单元将所述真实样本和生成样本映射到特征表示空间内包括:对真实样本的特征和生成样本的特征进行降维处理。
在本申请第二方面的某些实施例中,所述特征距离为经所述映射网络模块映射到特征表示空间内的真实样本的特征和生成样本的特征之间的最大均值差异。
在本申请第二方面的某些实施例中,所述训练模块还包括用于计算所述特征距离的计算单元。
在本申请第二方面的某些实施例中,所述训练模块还包括:第二更新单元,用于基于所述真实样本和所述生成样本之间的特征距离而确定的所述第二参数的偏移量,更新所述第二参数;第二判别单元,用于在判断利用更新后的第二参数所得到的特征距离符合最大化条件时停止第二更新单元的更新。
在本申请第二方面的某些实施例中,所述第二参数的偏移量包括:基于所述真实样本和所述生成样本之间的特征距离在第二参数上的梯度变化而得到的偏移量。
在本申请第二方面的某些实施例中,所述训练模块还包括:第一更新单元,用于在训练后的第二参数的条件下,基于所述真实样本和每次更新第一参数后生成网络输出的生成样本之间的特征距离而确定的所述第一参数的偏移量,更新所述生成网络的第一参数;第一判别单元,用于在判断基于每次更新后的第一参数而输出的生成样本与所述真实样本之间的特征距离符合最小化条件时停止第一更新单元的更新。
在本申请第二方面的某些实施例中,所述第一参数的的偏移量包括:基于所述真实样本和每次更新第一参数后生成网络输出的生成样本之间的特征距离在第一参数上的梯度变化而得到的偏移量。
在本申请第二方面的某些实施例中,所述生成网络模块还用于获取一随机变量,所述生成样本即为所述生成网络模块基于所述随机变量生成的。
在本申请第二方面的某些实施例中,所述随机变量遵循于一正态分布或一均匀分布。
在本申请第二方面的某些实施例中,所述真实样本为参考图像,所述生成样本为所述生成网络生成的图像。
本申请的第三方面公开一种图像生成方法,包括以下步骤:获取一原始输入图像;利用本申请第一方面公开的任一实施例所述的生成网络训练方法获得的生成网络对所述原始输入图像进行处理以输出一生成图像。
本申请的第四方面公开一种图像生成设备,包括:图像采集装置,用于获取待处理的原始输入图像;存储装置,用于存储计算机程序;处理装置,通信连接所述图像采集装置及存储装置,用于运行所述计算机程序来执行如本申请第三方面公开的图像生成方法。
本申请的第五方面公开一种图像生成客户端,装载于一智能设备;所述客户端包括:输入模块,用于接收到生成指令时,调用所述智能设备的图像采集装置获取待处理的原始输入图像;处理模块,用于执行如本申请第三方面公开的图像生成方法,以获得生成图像。
在本申请第五方面的某些实施例中,所述客户端还包括一显示模块,用于显示所述生成图像。
本申请的第六方面公开一种计算机可读存储介质,存储有至少一计算机程序,所述计算机程序被用于执行实现本申请第一方面公开的任一实施例所述的生成网络的训练方法或执行实现本申请第三方面公开的任一实施例所述的图像生成的方法。
综上所述,本申请提供的生成网络的训练方法、生成网络的训练系统、图像生成方法、图像生成设备、图像生成客户端及计算机可读存储介质,通过引入映射网络与生成网络进行对抗训练,同时提高二者的鲁棒性和准确性,进而有效的提高图像生成的鲁棒性和质量,并利用映射网络的映射功能将高维数据映射到低维数据,以低维数据来度量真实样本和生成样本之间的距离,在大大降低了计算量和节约内存的同时提高了生成网络对高维数据的处理能力以及生成高维图像的性能。并且本申请中采用最大均值差异来评估真实样本和生成样本之间的差异,评估效果更优良,从而对生成网络的训练更精准。
附图说明
本申请所涉及的发明的具体特征如所附权利要求书所显示。通过参考下文中详细描述的示例性实施方式和附图能够更好地理解本申请所涉及发明的特点和优势。对附图简要说明书如下:
图1显示为本申请在一实施例中的生成对抗网络的模型结构框图。
图2显示为本申请在一实施例中的生成网络训练方法的模型框架示意图。
图3显示为本申请在一实施例中的生成网络的训练方法的流程图。
图4显示为本申请在一实施例中的对映射网络的第二参数进行训练的流程图。
图5显示为本申请在一实施例中的训练系统的结构框图。
图6显示为本申请在一实施例中的图像生成方法的流程图。
图7显示为本申请在一实施例中的图像生成设备结构示意图。
图8显示为本申请在一实施例中的图像生成客户端结构示意图。
具体实施方式
以下由特定的具体实施例说明本申请的实施方式,熟悉此技术的人士可由本说明书所揭露的内容轻易地了解本申请的其他优点及功效。
在下述描述中,参考附图,附图描述了本申请的若干实施例。应当理解,还可使用其他实施例,并且可以在不背离本公开的精神和范围的情况下进行机械组成、结构、电气以及操作上的改变。下面的详细描述不应该被认为是限制性的,并且本申请的实施例的范围仅由公布的专利的权利要求书所限定。这里使用的术语仅是为了描述特定实施例,而并非旨在限制本申请。空间相关的术语,例如“上”、“下”、“左”、“右”、“下面”、“下方”、“下部”、“上方”、“上部”等,可在文中使用以便于说明图中所示的一个元件或特征与另一元件或特征的关系。
虽然在一些实例中术语第一、第二等在本文中用来描述各种元件或参数,但是这些元件或参数不应当被这些术语限制。这些术语仅用来将一个元件或参数与另一个元件或参数进行区分。例如,第一参数可以被称作第二参数,并且类似地,第二参数可以被称作第一参数,而不脱离各种所描述的实施例的范围。第一参数和第二参数均是在描述一个参数,但是除非上下文以其他方式明确指出,否则它们不是同一个参数。相似的情况还包括第一更新单元与第二更新单元,或者第一判别单元与第二判别单元。
再者,如同在本文中所使用的,单数形式“一”、“一个”和“该”旨在也包括复数形式,除非上下文中有相反的指示。应当进一步理解,术语“包含”、“包括”表明存在所述的特征、步骤、操作、元件、组件、项目、种类、和/或组,但不排除一个或多个其他特征、步骤、操作、元件、组件、项目、种类、和/或组的存在、出现或添加。此处使用的术语“或”和“和/或”被解释为包括性的,或意味着任一个或任何组合。因此,“A、B或C”或者“A、B和/或C”意味着“以下任一个:A;B;C;A和B;A和C;B和C;A、B和C”。仅当元件、功能、步骤或操作的组合在某些方式下内在地互相排斥时,才会出现该定义的例外。
如背景技术里介绍所言,无监督式的训练方式大都准确率低,而通过采用二人博弈思想的生成式对抗网络作为深度学习模型能大大提升无监督式学习的性能。生成式对抗网络在网络结构上除了生成网络外,还包含一个与生成网络对抗的判别网络。生成网络作为一生成器,用于生成类似于真实样本的随机样本,并将其作为生成样本,判别网络作为一判别器,用于分辨接收到的数据是来自于真实样本还是生成样本。生成网络和对抗网络相互对抗,在对抗中生成网络不断更新其生成样本的能力以产生能够以假乱真(也即,对抗网络分辨不出)的生成样本,而判别网络也随着生成网络的更新而不断更新其分辨真假样本的能力,从而能够辨别出来自于生成网络的生成样本。以下结合图1对生成式对抗网络的工作原理进行说明。
请参阅图1,显示为本申请在一实施例中的生成对抗网络的模型结构框图,这里以生成网络用于生成图片为例进行说明,如图所示,生成式对抗网络包括生成网络G和判别网络D。其中,生成网络G是一个生成图片的网络,它接收一个随机噪声z,通过该噪声z生成的图片记为G(z)。判别网络D用来判别一张图片是不是“真实的”,也即该图片是不是真实样本xx代表一张图片),以判别网络D的输入参数为x来看,其输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表是来自于生成网络生成的图片G(z)。在训练过程中,生成网络G的目标就是尽量生成接近真实图片x去欺骗判别网络D。而D的目标就是尽量把G生成的图片G(z)和真实图片x分别开来。在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z))=0.5。这样我们的目的就达成了:我们得到了一个生成网络G,它可以用来生成图片。
但是,实际应用中,由于上述生成式对抗网络在训练中由于存在的参数量过多,对内存和计算量要求都很高,在对高维数据样本的处理能力方面明显速度慢、效率低、并且输出的图片质量很差,达不到预定要求。另外,上述生成式对抗网络模型是通过最小化真实图片x和生成图片G(z)的数据分布之间的距离来拟合生成网络的,在生成式对抗网络中已经使用的用来度量两个分布之间的距离包括:Jensen-Shannon散度、f-散度、以及Wasserstein距离,选择合适的距离来度量对生成网络的训练也是非常重要的。
鉴于此,在可能的实施方式中,本申请提出一种生成网络的训练方法,请参阅图2,显示为本申请在一实施例中的生成网络训练方法的模型框架示意图,如图所示,本申请生成网络的训练方法通过设置一映射网络F与生成网络G对抗来达到训练生成网络的目的,映射网络F和生成网络G均是人工神经网络,其中,生成网络G具有第一参数
Figure 830174DEST_PATH_IMAGE001
,映射网络F具有第二参数
Figure 378967DEST_PATH_IMAGE002
。请参阅图3,显示为本申请在一实施例中的生成网络的训练方法的流程图,如图所示,所述生成网络的训练方法包括步骤S10、步骤S11、步骤S12、以及步骤S13。
在步骤S10中,获取真实样本和生成样本,所述生成样本是由使用第一参数的生成网络所生成的。
其中,映射网络F获取真实样本X和生成样本Y
结合图1,所述真实样本X为生成网络G的目标样本或参考样本。例如,真实样本X为一个或多个遵循真实数据分布Pr的真实数据x的集合,表示为
Figure 404692DEST_PATH_IMAGE003
,则以真实样本X作为训练目标,期望生成网络G具有能够模拟真实数据分布Pr的能力,从而使得其生成的生成样本Y,表示为
Figure 671725DEST_PATH_IMAGE004
,能够逼近于真实样本X。其中,假设生成网络G生成的生成样本Y遵循生成数据分布Pg,也即,期望生成数据分布Pg与真实数据分布Pr能够相同。在实际应用中,所述真实样本X可例如为参考图像(可为现有的一张图像)或参考文本等,生成样本Y则为生成网络G基于第一参数
Figure 605046DEST_PATH_IMAGE001
而输出的生成图像或生成文本等,对生成网络G的训练即为对第一参数
Figure 273925DEST_PATH_IMAGE001
的调整以使得生成图像能够无限接近于参考图像,或者生成文本能够无限接近于参考文本。
所述生成样本Y是生成网络G由使用第一参数
Figure 532868DEST_PATH_IMAGE001
的生成网络G将输入的数据转换而生成的,是故,为了能够获取到生成样本Y,所述训练方法还包括生成网络G获取一随机变量Z的步骤,所述生成样本Y即为生成网络G基于所述随机变量Z作为输入数据而生成的,即Y=G(Z)。其中,所述随机变量Z为遵循于正太分布或均匀分布的一个或多个随机数z,其数量对应于真实样本X中真实数据x的数量,表示为
Figure 208568DEST_PATH_IMAGE005
Figure 211159DEST_PATH_IMAGE006
在步骤S11中,利用所述真实样本X和所述生成样本Y对映射网络F的第二参数
Figure 734545DEST_PATH_IMAGE002
进行训练,直至所训练的映射网络F使得所述真实样本X和所述生成样本Y之间的特征距离符合最大化条件。
进一步地,在步骤S11中,请参阅图4,显示为本申请在一实施例中的对映射网络的第二参数进行训练的流程图,如图所示,所述利用真实样本X和生成样本Y对映射网络F的第二参数
Figure 898810DEST_PATH_IMAGE002
进行训练的步骤包括步骤S110、步骤S111、以及步骤S112。
在步骤S110中,计算经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的特征距离。
所述映射网络F作为额外引入的一个人工神经网络,可以帮助生成网络G更好的模拟真实数据分布Pr,从而最终使得生成网络G的生成数据分布Pg与真实样本的数据分布Pr之间的差异为零或接近于零。在一些实施例中,所述映射网络F用于将真实样本X和生成样本Y映射到特征表示空间内,也即,映射网络F用于学习其输入数据的特征表示,以特征的方式来表示真实样本X和生成样本Y。例如在图像处理领域,所述映射网络F可采用一卷积神经网络(convolutional neural network,CNN)来学习图像的特征表示。
更进一步地,为了降低计算量以及内存需求,所述映射网络F通过降维操作的方式将真实样本X和生成样本Y映射到特征表示空间内,具体的,可以采用特征选择或特征提取的方式来实现特征降维。其中,所述特征选择为从高维度的特征中选择其中的一个子集来作为新的特征,所述特征提取是指将高维度的特征经过某个降维函数映射至低维度作为新的特征。举例来说,特征提取方法包括主成分分析法(PCA)、奇异值分解法(SVD)、线性判别分析法(LDA)等,但考虑到降维效果和对样本标注的需求,本实施例中可以采用PCA来进行特征降维。但并不以此为限,还可以使用降维函数来实现特征降维。
是故,经所述映射网络F得到的真实样本F(X)和生成样本F(Y),即为映射网络F对真实样本X和生成样本Y执行上述处理之后得到的对应于真实样本X的低维度特征和对应于生成样本Y的低纬度特征。对应到真实样本X和生成样本Y所遵循的数据分布上,上述过程记为F(Pr)和F(Pg),并且F(X)遵循经映射网络F处理后的真实数据分布F(Pr),F(Y)遵循经映射网络F处理后的生成数据分布F(Pg)。如此,度量的真实样本X所遵循的真实数据分布Pr与生成样本Y所遵循的生成数据Pg之间距离,等价于度量F(X)遵循的真实数据分布F(Pr)和F(Y)遵循的生成数据分布F(Pg)之间的距离,由于F(X)、F(Pr)、F(Y)、以及F(Pg)的维度较低,故而,使得大大降低计算量,减轻了计算平台的负担。
为了能够更好的评估生成网络G用于生成样本Y的性能,在本申请中,采用最大均值差异(Maximum Mean Discrepancy,MMD)来度量经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的距离,对应到真实样本X和生成样本Y所遵循的数据分布上,也即是采用最大均值差异MMD度量F(Pr)和F(Pg)之间的距离。
其中,最大均值差异MMD是通过寻找在样本空间上的连续函数f,分别求来自两个不同数据分布的样本(如遵循F(Pr)的F(X)和遵循F(Pg)的F(Y))在f函数上的函数值的均值,对两个均值作差可以得到这两个分布(F(Pr)和F(Pg))对应于f的均值差异(meandiscrepancy,MD),确定一个f使得均值差异MD有最大值,就得到了最大均值差异MMD。其中,样本空间例如为希尔伯特空间。以公式来定义,则映射网络F处理后真实数据分布F(Pr)和生成数据分布F(Pg)之间的最大均值差异MMD定义为:
Figure 140435DEST_PATH_IMAGE007
其中,
Figure 884400DEST_PATH_IMAGE008
表示遵循真实数据分布F(Pr)的真实样本F(X)在f函数上期望值,
Figure 262292DEST_PATH_IMAGE009
表示遵循生成数据分布F(Pg)的真实样本F(Y)在f函数上期望值。故而,
Figure 863038DEST_PATH_IMAGE010
表示在希尔伯特空间H的单位球中求取F(X)和F(Y)在函数f上的最大均值差异MMD。由于,映射网络F基于第二参数
Figure 326380DEST_PATH_IMAGE002
对其获取的样本进行处理,生成网络G基于第一参数
Figure 670774DEST_PATH_IMAGE001
对其输入数据进行处理而生成的生成样本Y,因此,
Figure 90123DEST_PATH_IMAGE011
既相关于映射网络F的第二参数
Figure 596190DEST_PATH_IMAGE002
,也相关于生成网络G的第一参数
Figure 812408DEST_PATH_IMAGE001
进一步地,为了求解上述公式,通过高斯内核k以及期望值的展开,上述公式可以采用下式来计算,如下:
Figure 694913DEST_PATH_IMAGE012
其中,
Figure 47397DEST_PATH_IMAGE013
等价于
Figure 724366DEST_PATH_IMAGE014
Figure 99984DEST_PATH_IMAGE015
遵循经映射网络F处理后的真实数据分布F(Pr),
Figure 786180DEST_PATH_IMAGE016
遵循于真实数据分布Pr
Figure 993171DEST_PATH_IMAGE017
等价于
Figure 106620DEST_PATH_IMAGE018
Figure 31851DEST_PATH_IMAGE019
遵循经映射网络F处理后的生成数据分布F(Pg),
Figure 440180DEST_PATH_IMAGE020
遵循于真实数据分布Pg。故而,
Figure 767256DEST_PATH_IMAGE021
等价于
Figure 51607DEST_PATH_IMAGE022
,其同前式
Figure 198554DEST_PATH_IMAGE023
,既相关于映射网络F的第二参数,也相关于生成网络G的第一参数
Figure 226553DEST_PATH_IMAGE001
因此,在本实施例中,采用
Figure 345819DEST_PATH_IMAGE024
来评估经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的差异,以此达到训练生成网络G的目的。
从生成网络G的训练目标上来看,期望生成网络G对于任意设置的映射网络F,都能使得生成样本F(Y)逼近于真实样本F(X),也即
Figure 801071DEST_PATH_IMAGE024
达到最小(最好为零或接近于零)。为了能够达到上述目标,我们只需先确定能够使得
Figure 700894DEST_PATH_IMAGE024
最大的映射网络F,如果训练的生成网络G能够使得该最大的
Figure 267004DEST_PATH_IMAGE025
取得最小值(最好为零或接近零),那么对任意设置的映射网络F,都能使得
Figure 303094DEST_PATH_IMAGE025
达到最小。因此,在当前生成网络G的第一参数
Figure 116198DEST_PATH_IMAGE001
的条件下,首先需要训练映射网络F的第二参数
Figure 237737DEST_PATH_IMAGE002
能使得
Figure 607539DEST_PATH_IMAGE024
达到最大,具体步骤如下:
在步骤S111中,基于所述特征距离确定所述第二参数的偏移量,以更新所述第二参数。
在此,所述特征距离为上述经映射网络F得到的真实样本F(X)和生成样本F(Y)之间的最大均值差异MMD,即
Figure 763714DEST_PATH_IMAGE026
。根据前述,以使得
Figure 560768DEST_PATH_IMAGE027
达到最大为目标来训练映射网络F的第二参数
Figure 107287DEST_PATH_IMAGE002
。其中,求解
Figure 15201DEST_PATH_IMAGE028
的最大值表示为:
Figure 25882DEST_PATH_IMAGE029
Figure 993838DEST_PATH_IMAGE030
表示为F的候选函数的集合。
于实际应用中,
Figure 89970DEST_PATH_IMAGE029
的求解中可能是无界的,为了解决这个问题,在该公式中加入正则项,如下所示:
Figure 988525DEST_PATH_IMAGE031
其中,
Figure 119292DEST_PATH_IMAGE032
表示正则项的权重,Reg为正则项,在实际应用中会由于采用的正则化方法不同而有所不同。
在一实施例中,采用标准梯度惩罚正则化的方法使得
Figure 258149DEST_PATH_IMAGE029
有界,则上式中Reg为标准梯度惩罚正则项GP,如下式所示:
Figure 575998DEST_PATH_IMAGE033
其中,
Figure 91293DEST_PATH_IMAGE034
。在本申请中,限制
Figure 14250DEST_PATH_IMAGE035
满足1-Lipschitz,
Figure 324008DEST_PATH_IMAGE036
表示映射网络F所应用的函数的第i项,d表示映射网络F最终映射到的特征表示空间的维数,
Figure 129153DEST_PATH_IMAGE037
则表示沿一条直线在真实数据分布Pr和生成数据分布Pg上采样的点对,从而利用梯度惩罚正则化方法求解
Figure 182560DEST_PATH_IMAGE038
的最大值,以此来更新映射网络F的第二参数
Figure 22340DEST_PATH_IMAGE002
在另一些实施例中,采用L1正则化或L2正则化来使得
Figure 689950DEST_PATH_IMAGE029
有界,则上式中的Reg可例如为L1正则项
Figure 247971DEST_PATH_IMAGE039
L2正则项
Figure 839489DEST_PATH_IMAGE040
。其中,
Figure 799355DEST_PATH_IMAGE041
表示映射网络F的归一化层的参数,也可理解为在此,映射网络F的第二参数
Figure 450916DEST_PATH_IMAGE002
由其归一化层的参数来表示,因为归一化层决定了归一化输出在非线性激活函数之前的比例。
如此,可以基于
Figure 168336DEST_PATH_IMAGE042
确定第二参数
Figure 563545DEST_PATH_IMAGE002
的偏移量(即为每次更新第二参数
Figure 377918DEST_PATH_IMAGE002
所变动的数值)。
在一实施例中,所述偏移量为基于
Figure 200380DEST_PATH_IMAGE042
在第二参数
Figure 467413DEST_PATH_IMAGE002
上的梯度而得到的偏移量,其中
Figure 856194DEST_PATH_IMAGE043
在第二参数
Figure 525073DEST_PATH_IMAGE002
上的梯度表示为:
Figure 518437DEST_PATH_IMAGE044
更进一步地,本实施例中采用优化器基于上述梯度来确定第二参数的偏移量,以 采用Adam优化器为例,第二参数的偏移量为Adam(),其中,
Figure 142316DEST_PATH_IMAGE047
为 Adam优化器的初始参数。代表步长,例如可取值为0.001;和代表指数衰减率,其中, 用来控制权重分配,通常取接近于1的值,默认设置为0.9,用来控制梯度平方的影响情 况,默认设置为0.999。但本申请并不以此为限,Adam优化器的初始参数可以依据实际情况 进行选择,优化器的种类也不仅限于Adam优化器,还可使用RMSprop优化器、Adadelta优化 器等。
如此,基于第二参数ω的偏移量Adam(
Figure 660388DEST_PATH_IMAGE046
)更新第二参数
Figure 548709DEST_PATH_IMAGE002
,表示为:
Figure 431215DEST_PATH_IMAGE051
需要说明的是,为了对更新前后的第二参数以示区别,
Figure 783699DEST_PATH_IMAGE052
表示为更新之后的第二参数,
Figure 726247DEST_PATH_IMAGE002
则表示当前映射网络F(本次更新前)的第二参数,其中,偏移量和梯度也均是依据当前的映射网络F的第二参数计算所得。
在步骤S112中,利用更新后的第二参数重复步骤S110和步骤S111,直至所得到的特征距离符合最大化条件。
在如步骤S111中更新一次第二参数后,更新后的第二参数
Figure 164182DEST_PATH_IMAGE052
可能并不是最优结果,也即,更新后的第二参数
Figure 37328DEST_PATH_IMAGE052
虽然使得
Figure 244319DEST_PATH_IMAGE028
增大,但并没有使其达到最大化条件。
鉴于此,需要利用更新后的第二参数
Figure 357768DEST_PATH_IMAGE052
重复步骤S111,直至
Figure 282999DEST_PATH_IMAGE027
达到最大值。也即,每次重复均是将最新一次更新后的
Figure 507307DEST_PATH_IMAGE052
作为当前映射网络F的第二参数
Figure 772066DEST_PATH_IMAGE002
以对映射网络F训练,也即,步骤S111中的第二参数
Figure 790838DEST_PATH_IMAGE002
更新公式可以以下式表示:
Figure 203365DEST_PATH_IMAGE053
其中,i表示第i次重复步骤S111。以下以需要重复三次步骤S111能够使得
Figure 231364DEST_PATH_IMAGE027
达到最大为例进行说明,第一次更新后的映射网络F的第二参数
Figure 412946DEST_PATH_IMAGE054
,第二次更新后的映射网络F的第二参数
Figure 868198DEST_PATH_IMAGE055
,第三次更新后的映射网络F的第二参数
Figure 689393DEST_PATH_IMAGE056
,以
Figure 255503DEST_PATH_IMAGE057
作为本轮中能使得
Figure 557172DEST_PATH_IMAGE027
达到最大的映射网络F的第二参数
Figure 183325DEST_PATH_IMAGE002
需要说明的是,以上重复步骤S111三次仅是为了说明步骤S111是如何重复的,并不是限制重复步骤S111三次为使得
Figure 304865DEST_PATH_IMAGE025
达到最大的判定标准。在一些实施例中,通过预设重复次数(例如100次)作为使得
Figure 612349DEST_PATH_IMAGE025
达到最大的标准,在重复步骤S111达到预设次数时,则认为
Figure 768524DEST_PATH_IMAGE022
达到最大;在另一些实施例中,将每次更新第二参数
Figure 565579DEST_PATH_IMAGE002
的偏移量是否低于预设值(例如,
Figure 174415DEST_PATH_IMAGE058
)作为使得
Figure 82328DEST_PATH_IMAGE022
达到最大的标准,每次更新后,第二参数
Figure 93009DEST_PATH_IMAGE002
的偏移量低于预设值,则认为每次更新对第二参数
Figure 268424DEST_PATH_IMAGE002
的影响很小,
Figure 364556DEST_PATH_IMAGE022
达到最大。实际应用中,本领域技术人员可以依据实际所想要达到的精度需求对达到最大值的标准进行设定,本申请并不以此为限。
在步骤S12中,利用训练后的映射网络对所述生成网络的第一参数进行训练,直至训练后的第二参数的条件下的所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合最小化条件。
步骤S11中训练的生成网络F的第二参数
Figure 76160DEST_PATH_IMAGE002
能够使得在当前生成网络G的第一参数
Figure 941348DEST_PATH_IMAGE059
的条件下经映射网络F处理的真实样本
Figure 80205DEST_PATH_IMAGE060
和生成样本
Figure 601316DEST_PATH_IMAGE061
之间的特征距离最大。在此,需要训练生成网络G的第一参数
Figure 116611DEST_PATH_IMAGE001
,使得真实样本
Figure 836305DEST_PATH_IMAGE060
和更新后的生成样本
Figure 146064DEST_PATH_IMAGE062
之间的特征距离符合最小化条件,从而使得生成网络G能够生成逼近于真实样本X的生成样本的目的。
需要说明的是,对第一参数
Figure 216788DEST_PATH_IMAGE001
进行一次更新并不能达到第二参数
Figure 270195DEST_PATH_IMAGE002
的条件下的所述真实样本X和训练后的生成网络G输出的生成样本
Figure 296925DEST_PATH_IMAGE063
之间的特征距离符合最小化条件。鉴于此,所述利用训练后的映射网络F对所述生成网络G的第一参数
Figure 777585DEST_PATH_IMAGE001
进行训练的步骤包括:在步骤S11训练后的第二参数
Figure 335606DEST_PATH_IMAGE002
的条件下,基于所述真实样本X和每次更新第一参数
Figure 927124DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 621411DEST_PATH_IMAGE064
之间的特征距离而确定的所述第一参数
Figure 538551DEST_PATH_IMAGE001
的偏移量,更新所述生成网络的第一参数
Figure 255971DEST_PATH_IMAGE001
,直至所述真实样本X和训练后的生成网络输出的生成样本
Figure 651181DEST_PATH_IMAGE065
之间的特征距离符合所述最小化条件。
其中,n表示特征距离符合所述最小化条件时第一参数
Figure 465553DEST_PATH_IMAGE001
更新的次数,
Figure 288015DEST_PATH_IMAGE063
表示第j次更新第一参数
Figure 555049DEST_PATH_IMAGE001
后生成网络G输出的生成样本。
其中,在步骤S11训练后的第二参数
Figure 940899DEST_PATH_IMAGE002
的条件下,也即是指,所述真实样本X和每次更新第一参数
Figure 609778DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 603142DEST_PATH_IMAGE063
之间的特征距离采用的是在步骤S11训练后的映射网络F对真实样本X和每次更新第一参数
Figure 357471DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 94483DEST_PATH_IMAGE063
分别处理之后得到的
Figure 617868DEST_PATH_IMAGE066
Figure 985396DEST_PATH_IMAGE067
之间的最大均值差异MMD。
具体地,首先基于步骤S11中训练后的映射网络F得到的真实样本
Figure 227021DEST_PATH_IMAGE066
j-1次更新第一参数
Figure 502145DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 145616DEST_PATH_IMAGE068
之间的
Figure 480782DEST_PATH_IMAGE069
确定第j次更新第一参数
Figure 209704DEST_PATH_IMAGE001
的偏移量。
在本实施例中,第j次更新第一参数
Figure 475469DEST_PATH_IMAGE001
的偏移量包括:基于
Figure 973446DEST_PATH_IMAGE069
在第j-1次更新的第一参数
Figure 745093DEST_PATH_IMAGE001
上的梯度而得到的偏移量,其中,
Figure 695732DEST_PATH_IMAGE069
在第j-1次更新的第一参数
Figure 578237DEST_PATH_IMAGE001
上的梯度表示为:
Figure 868404DEST_PATH_IMAGE070
更进一步地,本实施例中采用优化器基于上述梯度来确定第j次更新第一参数 的偏移量,以采用Adam优化器为例,第j次更新第一参数的偏移量为Adam(
Figure 142074DEST_PATH_IMAGE072
),其中,
Figure 255523DEST_PATH_IMAGE073
为Adam优化器的初始参数。代表步长,例如可取值为 0.001;和代表指数衰减率,其中,用来控制权重分配,通常取接近于1的值,默认设置 为0.9,用来控制梯度平方的影响情况,默认设置为0.999。但本申请并不以此为限, Adam优化器的初始参数可以依据实际情况进行选择,优化器的种类也不仅限于Adam优化 器,还可使用RMSprop优化器、Adadelta优化器等。
如此,基于第j次更新第一参数的偏移量Adam(
Figure 690254DEST_PATH_IMAGE075
)更新第一参数,表示为:
Figure 156188DEST_PATH_IMAGE076
以下以需要更新三次第一参数
Figure 457856DEST_PATH_IMAGE001
能够使得
Figure 818430DEST_PATH_IMAGE077
符合最小化条件为例进行说明,第一次更新后的生成网络G的第一参数
Figure 392500DEST_PATH_IMAGE078
,第二次更新后的生成网络G的第一参数
Figure 762301DEST_PATH_IMAGE079
,第三次更新后的生成网络G的第一参数
Figure 652897DEST_PATH_IMAGE080
,以
Figure 449952DEST_PATH_IMAGE081
作为本轮中能使得
Figure 58787DEST_PATH_IMAGE082
符合最小化条件的生成网络G的第一参数。
需要说明的是,以上更新三次第一参数
Figure 169963DEST_PATH_IMAGE059
仅是为了说明示意,并不是限制更新三次第一参数
Figure 180644DEST_PATH_IMAGE059
为使得
Figure 148600DEST_PATH_IMAGE083
符合最小化条件的判定标准。在一些实施例中,通过预设更新次数(例如100次)作为使得
Figure 244732DEST_PATH_IMAGE083
符合最小化条件的标准,在更新第一参数
Figure 956336DEST_PATH_IMAGE001
达到预设次数时,则认为
Figure 821524DEST_PATH_IMAGE083
符合最小化条件;在另一些实施例中,将每次更新第一参数
Figure 147332DEST_PATH_IMAGE001
的偏移量是否低于预设值(例如,
Figure 730760DEST_PATH_IMAGE084
)作为使得
Figure 980476DEST_PATH_IMAGE082
符合最小化条件的标准,每次更新后,第一参数
Figure 965749DEST_PATH_IMAGE001
的偏移量低于预设值,则认为每次更新对第一参数
Figure 213191DEST_PATH_IMAGE001
的影响很小,
Figure 283915DEST_PATH_IMAGE085
符合最小化条件,实际应用中,本领域技术人员可以依据实际所想要达到的精度需求对达到最大值的标准进行设定,本申请并不以此为限。
在步骤S13中,重复上述过程直至经训练得到的第一参数满足预设收敛条件。
经过上述步骤S11和步骤S12,映射网络F与生成网络G完成一次对抗,使得新的生成网络G相较于未训练前的具有能够生成更逼近于真实样本的生成样本的能力。在实施例中,为了更好的优化生成网络G的性能,会重复执行步骤S10至S12直到经训练得到的第一参数满足预设收敛条件。需要说明的是,每次重复过程中当前生成网络G的第一参数
Figure 337322DEST_PATH_IMAGE001
和当前映射网络F的第二参数
Figure 177102DEST_PATH_IMAGE002
分别为上次更新后的数值。
在一些实施例中,所述预设收敛条件为得到的第一参数与上一次重复步骤S10至步骤S12得到的第一参数的差值低于预设值。如此,每次重复步骤S10至步骤S12后,要判断本次与上次重复后得到的第一参数的差值是否低于预设值。如果高于预设值,则第一参数不能收敛,继续重复步骤S10至步骤S12;如果低于预设值,则认为训练的第一参数已经趋于稳定了(或者说已经可收敛了),对生成网络G的训练停止。
在另一些实施例中,所述预设收敛条件为重复步骤S10至步骤S12的次数达到预设次数,也即是,认为在执行预定次数的步骤S10至步骤S12后,得到的第一参数便会趋于稳定了,对生成网络G的训练也就结束了。
综上所述,本申请提出的生成网络的训练方法简单易行,并且大大降低了计算量和内存需求。另外训练出的生成网络性能稳定,能够生成更真实的图像,尤其在生成高维图像方面具有很好的表现力。
本申请还提出一种生成网络的训练系统,请参阅图5,显示为本申请在一实施例中的训练系统的结构框图,如图所示,所述生成网络的训练系统包括映射网络模块F、生成网络模块G、以及训练模块T。其中,所述生成网络模块G具有第一参数
Figure 657762DEST_PATH_IMAGE001
,用于基于第一参数
Figure 137154DEST_PATH_IMAGE001
输出生成样本Y,所述映射网络模块F具有第二参数
Figure 994251DEST_PATH_IMAGE002
,用于获取真实样本X和生成样本Y,并基于第二参数
Figure 688538DEST_PATH_IMAGE002
输出映射后的真实样本F(X)和生成样本F(Y),所述训练模块T用于对生成网络模块G的第一参数
Figure 340099DEST_PATH_IMAGE001
和映射网络模块F的第二参数
Figure 385415DEST_PATH_IMAGE002
进行训练。
所述映射网络模块F具有第二参数
Figure 780625DEST_PATH_IMAGE002
,用于获取真实样本X和生成样本Y
如图5所示,映射网络模块F获取的真实样本X为生成网络模块G的目标样本或参考样本。例如,真实样本X为一个或多个遵循真实数据分布Pr的真实数据x的集合,表示为
Figure 267101DEST_PATH_IMAGE086
,则以真实样本X作为训练目标,期望生成网络模块G具有能够模拟真实数据分布Pr的能力,从而使得其生成的生成样本Y,表示为
Figure 355142DEST_PATH_IMAGE087
,能够逼近于真实样本X。其中,假设生成网络模块G生成的生成样本Y遵循生成数据分布Pg,也即,期望生成数据分布Pg与真实数据分布Pr能够相同。在实际应用中,所述真实样本X可例如为参考图像(可为现有的一张图像)或参考文本等,生成样本Y则为生成网络模块G基于第一参数
Figure 622176DEST_PATH_IMAGE001
而输出的生成图像或生成文本等,对生成网络模块G的训练即为对第一参数
Figure 821076DEST_PATH_IMAGE001
的调整以使得生成图像能够无限接近于参考图像,或者生成文本能够无限接近于参考文本。
映射网络模块F获取的生成样本Y是生成网络模块G由使用第一参数
Figure 224375DEST_PATH_IMAGE001
的生成网络模块G将输入的数据转换而生成的,是故,为了能够获取到生成样本Y,所述训练方法还包括生成网络G获取一随机变量Z的步骤,所述生成样本Y即为生成网络G基于所述随机变量Z作为输入数据而生成的,即Y=G(Z)。其中,所述随机变量Z为遵循于正太分布或均匀分布的一个或多个随机数z,其数量对应于真实样本X中真实数据x的数量,表示为
Figure 483318DEST_PATH_IMAGE088
Figure 421669DEST_PATH_IMAGE089
在一实施例中,如图5所示,所述映射网络模块F包括特征表示单元,所述特征表示单元用于将真实样本X和生成样本Y映射到特征表示空间内,也即,特征表示单元用于学习其输入数据的特征表示,以特征的方式来表示真实样本X和生成样本Y
更进一步地,为了降低计算量以及内存需求,所述特征表示单元还对真实样本X的特征和生成样本Y的特征进行降维处理。具体的,可以采用特征选择或特征提取的方式来实现特征降维。其中,所述特征选择为从高维度的特征中选择其中的一个子集来作为新的特征,所述特征提取是指将高维度的特征经过某个降维函数映射至低维度作为新的特征。举例来说,特征提取方法包括主成分分析法(PCA)、奇异值分解法(SVD)、线性判别分析法(LDA)等,但考虑到降维效果和对样本标注的需求,本实施例中可以采用PCA来进行特征降维。但并不以此为限,还可以使用降维函数来实现特征降维。
是故,经所述映射网络模块F得到的真实样本F(X)和生成样本F(Y),即为映射网络模块F对真实样本X和生成样本Y执行上述处理之后得到的对应于真实样本X的低维度特征和对应于生成样本Y的低纬度特征。对应到真实样本X和生成样本Y所遵循的数据分布上,上述过程记为F(Pr)和F(Pg),并且F(X)遵循经映射网络模块F处理后的真实数据分布F(Pr),F(Y)遵循经映射网络模块F处理后的生成数据分布F(Pg)。如此,后续训练模块T在度量真实样本X和生成样本Y之间的距离时,等价于度量经所述映射网络模块F得到的真实样本F(X)和生成样本F(Y)之间的距离,或在度量的真实样本X所遵循的真实数据分布Pr与生成样本Y所遵循的生成数据Pg之间距离时,等价于度量F(X)遵循的真实数据分布F(Pr)和F(Y)遵循的生成数据分布F(Pg)之间的距离,由于F(X)、F(Pr)、F(Y)、以及F(Pg)的维度较低,故而,使得大大降低计算量,减轻了计算平台的负担。
以下结合图5来说明训练模块T是如何对生成网络模块G的第一参数
Figure 158681DEST_PATH_IMAGE001
和映射网络模块F的第二参数
Figure 682066DEST_PATH_IMAGE002
进行训练的。
具体地,训练模块T基于第一参数
Figure 111910DEST_PATH_IMAGE001
对映射网络模块F的第二参数
Figure 87956DEST_PATH_IMAGE002
进行训练,直至所训练的映射网模块络F使得所述真实样本X和所述生成样本Y之间的特征距离符合最大化条件。其中,所述特征距离即为经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的距离。
如图5所示,所述训练模块T包括计算单元,用于计算经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的距离。
为了能够更好的评估生成网络模块G用于生成样本Y的性能,在本申请中,采用最大均值差异MMD来度量经所述映射网络模块F得到的真实样本F(X)和生成样本F(Y)之间的距离,对应到真实样本X和生成样本Y所遵循的数据分布上,也即是采用最大均值差异MMD度量F(Pr)和F(Pg)之间的距离。
其中,最大均值差异MMD是通过寻找在样本空间上的连续函数f,分别求来自两个不同数据分布的样本(如遵循F(Pr)的F(X)和遵循F(Pg)的F(Y))在f函数上的函数值的均值,对两个均值作差可以得到这两个分布(F(Pr)和F(Pg))对应于f的均值差异(meandiscrepancy,MD),确定一个f使得均值差异MD有最大值,就得到了最大均值差异MMD。其中,样本空间例如为希尔伯特空间。以公式来定义,则映射网络模块F处理后真实数据分布F(Pr)和生成数据分布F(Pg)之间的最大均值差异MMD定义为:
Figure 566342DEST_PATH_IMAGE007
其中,
Figure 209813DEST_PATH_IMAGE090
表示遵循真实数据分布F(Pr)的真实样本F(X)在f函数上期望值,
Figure 544980DEST_PATH_IMAGE091
表示遵循生成数据分布F(Pg)的真实样本F(Y)在f函数上期望值。故而,
Figure 273901DEST_PATH_IMAGE092
表示在希尔伯特空间H的单位球中求取F(X)和F(Y)在函数f上的最大均值差异MMD。由于,映射网络模块F基于第二参数
Figure 352716DEST_PATH_IMAGE002
对其获取的样本进行处理,生成网络模块G基于第一参数
Figure 850693DEST_PATH_IMAGE001
对其输入数据进行处理而生成的生成样本Y,因此,
Figure 543712DEST_PATH_IMAGE093
既相关于映射网络模块F的第二参数
Figure 759929DEST_PATH_IMAGE002
,也相关于生成网络模块G的第一参数
Figure 642435DEST_PATH_IMAGE001
进一步地,为了求解上述公式,通过高斯内核k以及期望值的展开,上述公式可以采用下式来计算,如下:
Figure 994919DEST_PATH_IMAGE094
其中,
Figure 937467DEST_PATH_IMAGE013
等价于
Figure 375401DEST_PATH_IMAGE014
Figure 733701DEST_PATH_IMAGE015
遵循经映射网络F处理后的真实数据分布F(Pr),
Figure 206271DEST_PATH_IMAGE016
遵循于真实数据分布Pr
Figure 54141DEST_PATH_IMAGE017
等价于
Figure 979372DEST_PATH_IMAGE018
Figure 203680DEST_PATH_IMAGE019
遵循经映射网络F处理后的生成数据分布F(Pg),
Figure 717707DEST_PATH_IMAGE020
遵循于真实数据分布Pg。故而,
Figure 2058DEST_PATH_IMAGE021
等价于
Figure 414584DEST_PATH_IMAGE022
,其同前式
Figure 177004DEST_PATH_IMAGE023
,既相关于映射网络F的第二参数
Figure 358587DEST_PATH_IMAGE002
,也相关于生成网络G的第一参数
Figure 813839DEST_PATH_IMAGE001
因此,在本实施例中,采用
Figure 651345DEST_PATH_IMAGE024
来评估经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的差异,以此达到训练第一参数
Figure 217455DEST_PATH_IMAGE001
和第二参数
Figure 253545DEST_PATH_IMAGE002
的目的。
从生成网络模块G的第一参数
Figure 879698DEST_PATH_IMAGE059
的训练目标上来看,期望生成网络模块G对于任意设置的映射网络模块F,都能使得生成样本F(Y)逼近于真实样本F(X),也即
Figure 266817DEST_PATH_IMAGE022
达到最小(最好为零或接近于零)。为了能够达到上述目标,我们只需先确定能够使得
Figure 557990DEST_PATH_IMAGE022
最大的映射网模块络F,如果训练的生成网络农垦G能够使得该最大的
Figure 714165DEST_PATH_IMAGE022
取得最小值(最好为零或接近零),那么对任意设置的映射网络模块F,都能使得
Figure 511219DEST_PATH_IMAGE095
达到最小。因此,在当前生成网络模块G的第一参数
Figure 120055DEST_PATH_IMAGE001
的条件下,首先训练模块T训练映射网络模块F的第二参数
Figure 293548DEST_PATH_IMAGE002
能使得
Figure 304229DEST_PATH_IMAGE022
符合最大化条件。
如图5,所述训练模块T还包括第二更新单元,第二更新单元基于经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的距离确定所述第二参数
Figure 209868DEST_PATH_IMAGE002
的偏移量,以更新所述第二参数。
在此,所述经所述映射网络F得到的真实样本F(X)和生成样本F(Y)之间的距离,即为
Figure 306000DEST_PATH_IMAGE025
。根据前述,以使得
Figure 752025DEST_PATH_IMAGE022
达到最大为目标来训练第二参数
Figure 882792DEST_PATH_IMAGE002
。其中,求解
Figure 21649DEST_PATH_IMAGE025
的最大值表示为:
Figure 794958DEST_PATH_IMAGE029
Figure 44673DEST_PATH_IMAGE030
表示为F的候选函数的集合。
于实际应用中,
Figure 29947DEST_PATH_IMAGE029
的求解中可能是无界的,为了解决这个问题,在该公式中加入正则项,如下所示:
Figure 339706DEST_PATH_IMAGE031
其中,
Figure 410430DEST_PATH_IMAGE032
表示正则项的权重,Reg为正则项,在实际应用中会由于采用的正则化方法不同而有所不同。
在一实施例中,采用标准梯度惩罚正则化的方法使得
Figure 463836DEST_PATH_IMAGE029
有界,则上式中Reg为标准梯度惩罚正则项GP,如下式所示:
Figure 241300DEST_PATH_IMAGE033
其中,
Figure 721959DEST_PATH_IMAGE096
。在本申请中,限制
Figure 14401DEST_PATH_IMAGE035
满足1-Lipschitz,
Figure 871498DEST_PATH_IMAGE097
表示映射网络F所应用的函数的第i项,d表示映射网络F最终映射到的特征表示空间的维数,
Figure 565785DEST_PATH_IMAGE098
则表示沿一条直线在真实数据分布Pr和生成数据分布Pg上采样的点对,从而利用梯度惩罚正则化方法求解
Figure 217346DEST_PATH_IMAGE099
的最大值,以此来更新映射网络F的第二参数
Figure 449613DEST_PATH_IMAGE002
在另一些实施例中,采用L1正则化或L2正则化来使得
Figure 844822DEST_PATH_IMAGE029
有界,则上式中的Reg可例如为L1正则项
Figure 393615DEST_PATH_IMAGE100
L2正则项
Figure 481657DEST_PATH_IMAGE101
。其中,
Figure 748690DEST_PATH_IMAGE041
表示映射网络F的归一化层的参数,也可理解为在此,映射网络F的第二参数
Figure 682011DEST_PATH_IMAGE002
由其归一化层的参数来表示,因为归一化层决定了归一化输出在非线性激活函数之前的比例。
如此,可以基于
Figure 288573DEST_PATH_IMAGE102
确定第二参数
Figure 547516DEST_PATH_IMAGE002
的偏移量(即为每次更新第二参数
Figure 36266DEST_PATH_IMAGE002
所变动的数值)。
在一实施例中,所述偏移量为基于
Figure 38857DEST_PATH_IMAGE102
在第二参数
Figure 562242DEST_PATH_IMAGE002
上的梯度而得到的偏移量,其中
Figure 913458DEST_PATH_IMAGE103
在第二参数
Figure 155084DEST_PATH_IMAGE002
上的梯度表示为:
Figure 695786DEST_PATH_IMAGE044
更进一步地,本实施例中采用优化器基于上述梯度来确定第二参数的偏移量,以 采用Adam优化器为例,第二参数的偏移量为Adam(),其中,
Figure 917820DEST_PATH_IMAGE105
为 Adam优化器的初始参数。代表步长,例如可取值为0.001;和代表指数衰减率,其中, 用来控制权重分配,通常取接近于1的值,默认设置为0.9,用来控制梯度平方的影响情 况,默认设置为0.999。但本申请并不以此为限,Adam优化器的初始参数可以依据实际情况 进行选择,优化器的种类也不仅限于Adam优化器,还可使用RMSprop优化器、Adadelta优化 器等。
如此,第二更新单元基于第二参数
Figure 863146DEST_PATH_IMAGE002
的偏移量Adam(
Figure 70136DEST_PATH_IMAGE046
)更新第二参数
Figure 183586DEST_PATH_IMAGE002
,表示为:
Figure 46499DEST_PATH_IMAGE051
其中,为了对更新前后的第二参数以示区别,
Figure 270807DEST_PATH_IMAGE052
表示为更新之后的第二参数,
Figure 597883DEST_PATH_IMAGE002
则表示当前映射网络模块F(本次更新前)的第二参数,其中,偏移量和梯度也均是依据当前的映射网络模块F的第二参数计算所得。
需要说明的是,在第二更新单元更新一次第二参数后,更新后的第二参数
Figure 882234DEST_PATH_IMAGE052
可能并不是最优结果,也即,更新后的第二参数
Figure 29182DEST_PATH_IMAGE052
虽然使得
Figure 241202DEST_PATH_IMAGE022
增大,但并没有使其达到最大化条件。故而,第二更新单元会利用每次更新后的第二参数重复上述过程,鉴于此,第二参数
Figure 422784DEST_PATH_IMAGE002
更新公式可以以下式表示:
Figure 878036DEST_PATH_IMAGE106
其中,i表示第二更新单元第i次更新第二参数
Figure 777859DEST_PATH_IMAGE002
。以下以更新三次第二参数
Figure 343970DEST_PATH_IMAGE002
为例进行说明,第一次更新后的映射网络模块F的第二参数
Figure 380059DEST_PATH_IMAGE107
,第二次更新后的映射网络模块F的第二参数
Figure 943895DEST_PATH_IMAGE108
,第三次更新后的映射网络模块F的第二参数
Figure 65435DEST_PATH_IMAGE109
需要说明的是,以上更新三次第二参数
Figure 435237DEST_PATH_IMAGE002
仅是为了说明第二参数
Figure 591412DEST_PATH_IMAGE002
是如何重复更新的,并不是指第二更新单元仅更新三次第二参数
Figure 388466DEST_PATH_IMAGE002
鉴于此,如图5,训练模块T还包括第二判别单元,用于在判断利用更新后的第二参数所得到的特征距离
Figure 997302DEST_PATH_IMAGE025
符合最大化条件时停止第二更新单元的更新。
在一些实施例中,通过预设更新次数(例如100次)作为使得
Figure 92166DEST_PATH_IMAGE022
达到最大化条件的标准,在第二更新单元的更新次数达到预设次数时,则认为
Figure 102847DEST_PATH_IMAGE022
达到最大;在另一些实施例中,将每次更新第二参数
Figure 70803DEST_PATH_IMAGE002
的偏移量是否低于预设值(例如,
Figure 166935DEST_PATH_IMAGE110
)作为使得
Figure 878539DEST_PATH_IMAGE025
达到最大化条件的标准,每次更新后,第二参数
Figure 946989DEST_PATH_IMAGE002
的偏移量低于预设值,则认为每次更新对第二参数
Figure 85847DEST_PATH_IMAGE002
的影响很小,
Figure 403696DEST_PATH_IMAGE024
达到最大。实际应用中,本领域技术人员可以依据实际所想要达到的精度需求对达到最大值的标准进行设定,本申请并不以此为限。
训练模块T还用于基于经上述训练后的映射网络模块F对所述生成网络模块G的第一参数
Figure 918991DEST_PATH_IMAGE001
进行训练,直至训练后的第二参数
Figure 904264DEST_PATH_IMAGE002
的条件下的所述X真实样本和训练后的生成网络输出的生成样本Y之间的特征距离符合最小化条件。
前述中,训练模块T训练的生成网络模块F的第二参数
Figure 214023DEST_PATH_IMAGE002
能够使得在当前生成网络模块G的第一参数
Figure 206118DEST_PATH_IMAGE001
的条件下经映射网络模块F处理的真实样本
Figure 259525DEST_PATH_IMAGE066
和生成样本
Figure 99305DEST_PATH_IMAGE061
之间的特征距离最大。在此,需要训练生成网络模块G的第一参数
Figure 579965DEST_PATH_IMAGE001
,使得真实样本
Figure 137985DEST_PATH_IMAGE066
和更新后的生成样本
Figure 667187DEST_PATH_IMAGE067
之间的距离符合最小化条件,从而使得生成网络G能够生成逼近于真实样本X的生成样本的目的。其中对真实样本
Figure 627052DEST_PATH_IMAGE066
和更新后的生成样本
Figure 278614DEST_PATH_IMAGE062
之间的距离计算依然由前述计算单元执行。
需要说明的是,对第一参数
Figure 58351DEST_PATH_IMAGE001
进行一次更新并不能达到第二参数
Figure 453560DEST_PATH_IMAGE002
的条件下的所述真实样本X和训练后的生成网络模块G输出的生成样本
Figure 267932DEST_PATH_IMAGE063
之间的特征距离符合最小化条件。鉴于此,如图5,所述训练模块T还包括第一更新单元和第一判别单元。
所述第一更新单元用于在训练后的第二参数
Figure 277346DEST_PATH_IMAGE002
的条件下,基于所述真实样本X和每次更新第一参数
Figure 544379DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 743279DEST_PATH_IMAGE111
之间的特征距离而确定的所述第一参数
Figure 412158DEST_PATH_IMAGE001
的偏移量,更新第一参数
Figure 405522DEST_PATH_IMAGE001
,直至所述真实样本X和训练后的生成网络模块G输出的生成样本
Figure 159851DEST_PATH_IMAGE065
之间的特征距离符合所述最小化条件。其中,n表示特征距离符合所述最小化条件时第一参数
Figure 100125DEST_PATH_IMAGE001
更新的次数,
Figure 623510DEST_PATH_IMAGE063
表示第j次更新第一参数后
Figure 787775DEST_PATH_IMAGE001
生成网络模块G输出的生成样本。
其中,在训练后的第二参数
Figure 29401DEST_PATH_IMAGE002
的条件下,也即是指,所述真实样本X和每次更新第一参数
Figure 570104DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 947995DEST_PATH_IMAGE063
之间的特征距离采用的是在步骤S11训练后的映射网络模块F对真实样本X和每次更新第一参数
Figure 738621DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 201964DEST_PATH_IMAGE063
分别处理之后得到的
Figure 546357DEST_PATH_IMAGE060
Figure 778756DEST_PATH_IMAGE067
之间的最大均值差异MMD。
具体地,第一更新单元首先基于训练后的映射网络模块F得到的真实样本
Figure 550403DEST_PATH_IMAGE060
j-1次更新第一参数
Figure 438724DEST_PATH_IMAGE001
后的生成网络G输出的生成样本
Figure 321229DEST_PATH_IMAGE068
之间的
Figure 673713DEST_PATH_IMAGE069
确定第j次更新第一参数
Figure 616262DEST_PATH_IMAGE001
的偏移量。
在本实施例中,第j次更新第一参数
Figure 54196DEST_PATH_IMAGE059
的偏移量包括:基于
Figure 740392DEST_PATH_IMAGE069
在第j-1次更新的第一参数
Figure 134334DEST_PATH_IMAGE001
上的梯度而得到的偏移量,其中,
Figure 247783DEST_PATH_IMAGE069
在第j-1次更新的第一参数
Figure 173014DEST_PATH_IMAGE001
上的梯度表示为:
Figure 397322DEST_PATH_IMAGE112
更进一步地,本实施例中采用优化器基于上述梯度来确定第j次更新第一参数 的偏移量,以采用Adam优化器为例,第j次更新第一参数的偏移量为Adam(
Figure 121378DEST_PATH_IMAGE114
),其中,
Figure 302961DEST_PATH_IMAGE047
为Adam优化器的初始参数。代表步长,例如可取值为 0.001;和代表指数衰减率,其中,用来控制权重分配,通常取接近于1的值,默认设置 为0.9,用来控制梯度平方的影响情况,默认设置为0.999。但本申请并不以此为限, Adam优化器的初始参数可以依据实际情况进行选择,优化器的种类也不仅限于Adam优化 器,还可使用RMSprop优化器、Adadelta优化器等。
如此,基于第j次更新第一参数的偏移量Adam(
Figure 658539DEST_PATH_IMAGE114
)更新第一参数,表示为:
Figure 64429DEST_PATH_IMAGE076
以下以第一更新单元更新三次第一参数
Figure 972342DEST_PATH_IMAGE001
为例进行说明,第一次更新后的生成网络模块G的第一参数
Figure 983024DEST_PATH_IMAGE115
,第二次更新后的生成网络模块G的第一参数
Figure 137931DEST_PATH_IMAGE116
,第三次更新后的生成网络模块G的第一参数
Figure 234062DEST_PATH_IMAGE117
需要说明的是,以上更新三次第一参数
Figure 945667DEST_PATH_IMAGE001
仅是为了说明示意,并不是限制第一更新单元仅更新三次第一参数
Figure 810854DEST_PATH_IMAGE001
鉴于此,所述第一判别单元用于在判断基于每次更新后的第一参数
Figure 949712DEST_PATH_IMAGE001
而输出的生成样本
Figure 533140DEST_PATH_IMAGE063
与所述真实样本X之间的特征距离
Figure 986118DEST_PATH_IMAGE118
符合最小化条件时停止第一更新单元的更新。
在一些实施例中,通过预设第一更新单元的更新次数(例如100次)作为使得符合最小化条件的标准,在更新第一参数达到预设次数时,则认为符合最小化条件;在另一些实施例中,将每次更新第一参数的偏移 量低于预设值(例如,
Figure 163503DEST_PATH_IMAGE121
)作为使得符合 最小化条件的标准,每次更新后,第一参数的偏移量低于预设值,则认为每次更新对第一 参数的影响很小,符合最小化条件,实际应用中,本领域技术人员可 以依据实际所想要达到的精度需求对达到最大值的标准进行设定,本申请并不以此为限。
经过训练模块T的上述处理,使得新的生成网络模块G相较于未训练前的具有能够生成更逼近于真实样本的生成样本的能力,新的映射网络模块F相较未训练前的具有辨别真实样本和生成样本的能力更强。在实施例中,为了更好的优化生成网络G的性能,训练模块T会重复上述训练过程直至经训练得到的第一参数满足预设收敛条件。需要说明的是,每次重复过程中当前生成网络模块G的第一参数
Figure 405128DEST_PATH_IMAGE001
和当前映射网络模块F的第二参数
Figure 122548DEST_PATH_IMAGE002
分别为上次更新后的数值。
鉴于此,所述训练模块T还可包括第三判别单元(未予以图示),用于在经训练所得到的第一参数满足预设收敛条件时停止训练模块T的更新,并将训练模块T停止更新后的第一参数
Figure 517758DEST_PATH_IMAGE001
反馈给生成网络模块G,以及将训练模块T停止更新后的第二参数
Figure 332130DEST_PATH_IMAGE002
反馈给映射网络模块F
在一些实施例中,所述预设收敛条件为本次得到的第一参数与上一次重复训练模块T的训练过程得到的第一参数的差值低于预设值。如此,每次重复训练模块T的训练过程后,要判断本次与上次重复后得到的第一参数的差值是否低于预设值。如果高于预设值,则第一参数不能收敛,继续重复训练模块T的训练过程;如果低于预设值,则认为训练的第一参数已经趋于稳定了(或者说已经可收敛了),则停止训练模块T的训练过程,将训练模块T停止更新后的第一参数反馈给生成网络模块G,以及将训练模块T停止更新后的第二参数
Figure 154592DEST_PATH_IMAGE002
反馈给映射网络模块F
在另一些实施例中,所述预设收敛条件为重复训练模块T的训练过程的次数达到预设次数,也即是,认为在执行预定次数的训练模块T的训练过程后,得到的第一参数便会趋于稳定了,训练模块T的训练过程即停止。
综上所述,本申请提出的生成网络的训练系统简单易行,并且大大降低了计算量和内存需求。另外训练出的生成网络模块性能稳定,能够生成更真实的图像,尤其在生成高维图像方面具有很好的表现力。
本申请还提出一种图像生成方法,请参阅图6,显示为本申请在一实施例中的图像生成方法的流程图,所述图像生成方法包括步骤S30和步骤S31。
在步骤S30中,获取一原始输入图像。
所述原始输入图像为待处理图像。所述待处理图像包括但不限于:通过使用手机、平板电脑、笔记本电脑等智能终端的摄像头进行拍摄或扫描得到图像、通过网络获取的电子图像、预先存储于所述各智能终端的图像。例如,所述待处理图像为摄像头扫描的老旧或污损照片、网络上下载的低分辨率的图像等。
在步骤S31中,利用本申请中生成网络训练方法获得的生成网络对所述原始输入图像进行处理以输出一生成图像。
将获取的原始输入图像输入至生成网络中,所述生成网络对原始输入图像进行重新生成,从而基于所述原始输入图像输出一生成图像。例如,在一示例中,所述原始输入图像为摄像头扫描的污损照片,经生成网络对污损照片进行处理,输出的生成图像为对还原后照片;在另一示例中,所述原始输入图像为网络上下载的低分辨率的图像,经生成网络对其处理,输出的生成图像为低分辨率的图像。
本申请还提出一种图像生成设备,请参阅图7,显示为本申请在一实施例中的图像生成设备结构示意图,如图所示,所述图像生成设备20包括图像采集装置21、存储装置22、和处理装置23,图像采集装置21、存储器22、和处理装置23可通过总线或其它方式通信连接,本申请通过总线连接为例,其中,图像采集装置21用于获取待处理的原始输入图像。
所述图像采集装置21包括智能终端或其它带有摄像头的装置或设备,所述原始输入图像包括但不限于:通过使用手机、平板电脑、笔记本电脑等智能终端的摄像头进行拍摄或扫描得到图像、通过智能终端下载的电子图像、预先存储于所述各智能终端的图像。例如,所述原始输入图像为摄像头扫描的老旧或污损照片、网络上下载的低分辨率的图像等。
所述存储装置22可包括随机存取存储器(Random Access Memory,RAM)、只读存储器(Read Only Memory,ROM)、可编程只读存储器(Programmable Read-Only Memory,PROM)、可擦可编程序只读存储器(Erasable Programmable Read-Only Memory,EPROM)、电可擦编程只读存储器Electric Erasable Programmable Read-Only Memory,EEPROM)等。存储装置22用于存储计算机程序,处理装置23在接收到执行指令后,执行该程序。
所述处理装置23包括集成电路芯片,具有信号处理能力;或通用处理器,例如,可以是数字信号处理器(DSP)、专用集成电路(ASIC)、分立门或晶体管逻辑器件、分立硬件组件,可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。所述通用处理器可以是微处理器或者任何常规处理器等,例如中央处理器(Central Processing Unit,CPU)。所述处理装置23用于调用所述存储装置22中存储的计算机程序来执行前述的图像生成方法,在此不再赘述。
本申请提供的图像生成方法还可借由以图像生成客户端来执行,所述图像生成客户端装载于一智能设备中,通过智能设备的软件和硬件来实现。所述图像生成客户端在智能设备上的装载形式包括APP应用程序或小程序(例如微信上的小程序等)。
请参阅图8,显示为本申请在一实施例中的图像生成客户端结构示意图,如图所示,所述装载于一智能设备中的图像生成客户端30包括输入模块31和处理模块32。其中:所述输入模块31用于接收到用户输入的生成指令时,调用所述智能设备的图像采集装置获取待处理的原始输入图像;所述处理模块32调用所述智能设备中存储的计算机程序来执行前述的图像生成方法。
在一实施例中,所述图像生成客户端30还包括显示模块(未予以图示),用于显示处理模块32执行前述图像生成方法所输出生成图像。
所述智能设备例如为包括但不限于智能手机、平板电脑、智能手表、个人数字助理(PDA)等等的电子设备,应当理解,本申请于实施方式中描述的电子设备只是一个应用实例,该设备的组件可以比图示具有更多或更少的组件,或具有不同的组件配置。所绘制图示的各种组件可以用硬件、软件或软硬件的组合来实现,包括一个或多个信号处理和/或专用集成电路。
本申请还提供一种计算机可读存储介质,存储有至少一计算机程序,所述计算机程序被用于执行本申请中提出的生成网络的训练方法或执行实现本申请中提出的生成图像的方法。
于本申请提供的实施例中,所述计算机可读写存储介质可以包括只读存储器、随机存取存储器、EEPROM、CD-ROM或其它光盘存储装置、磁盘存储装置或其它磁存储设备、闪存、U盘、移动硬盘、或者能够用于存储具有指令或数据结构形式的期望的程序代码并能够由计算机进行存取的任何其它介质。另外,任何连接都可以适当地称为计算机可读介质。例如,如果指令是使用同轴电缆、光纤光缆、双绞线、数字订户线(DSL)或者诸如红外线、无线电和微波之类的无线技术,从网站、服务器或其它远程源发送的,则所述同轴电缆、光纤光缆、双绞线、DSL或者诸如红外线、无线电和微波之类的无线技术包括在所述介质的定义中。然而,应当理解的是,计算机可读写存储介质和数据存储介质不包括连接、载波、信号或者其它暂时性介质,而是旨在针对于非暂时性、有形的存储介质。如申请中所使用的磁盘和光盘包括压缩光盘(CD)、激光光盘、光盘、数字多功能光盘(DVD)、软盘和蓝光光盘,其中,磁盘通常磁性地复制数据,而光盘则用激光来光学地复制数据。
在一个或多个示例性方面,本申请所述生成网络的训练方法和图像生成方法的计算机程序所描述的功能可以用硬件、软件、固件或者其任意组合的方式来实现。当用软件实现时,可以将这些功能作为一个或多个指令或代码存储或传送到计算机可读介质上。本申请所公开的方法或算法的步骤可以用处理器可执行软件模块来体现,其中处理器可执行软件模块可以位于有形、非临时性计算机可读写存储介质上。有形、非临时性计算机可读写存储介质可以是计算机能够存取的任何可用介质。
本申请上述的附图中的流程图和框图,图示了按照本申请各种实施例的系统、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段、或代码的一部分,该模块、程序段、或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个接连地表示的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这根据所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以通过执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以通过专用硬件与计算机指令的组合来实现。
本申请提供的生成网络的训练方法、生成网络的训练系统、图像生成方法、图像生成设备、图像生成客户端及计算机可读存储介质,通过引入映射网络与生成网络进行对抗训练,同时提高二者的鲁棒性和准确性,进而有效的提高图像生成的鲁棒性和质量,并利用映射网络的映射功能将高维数据映射到低维数据,以低维数据来度量真实样本和生成样本之间的距离,在大大降低了计算量和节约内存的同时提高了生成网络对高维数据的处理能力以及生成高维图像的性能。并且本申请中采用最大均值差异来评估真实样本和生成样本之间的差异,评估效果更优良,从而对生成网络的训练更精准。
上述实施例仅例示性说明本申请的原理及其功效,而非用于限制本申请。任何熟悉此技术的人士皆可在不违背本申请的精神及范畴下,对上述实施例进行修饰或改变。因此,举凡所属技术领域中具有通常知识者在未脱离本申请所揭示的精神与技术思想下所完成的一切等效修饰或改变,仍应由本申请的权利要求所涵盖。

Claims (28)

1.一种生成网络的训练方法,其特征在于,包括以下步骤:
获取真实样本和生成样本,所述生成样本是由使用第一参数的生成网络所生成的;
利用所述真实样本和所述生成样本对映射网络的第二参数进行训练,直至所训练的映射网络使得所述真实样本和所述生成样本之间的特征距离符合最大化条件;
利用训练后的映射网络对所述生成网络的第一参数进行训练,直至训练后的第二参数的条件下的所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合最小化条件;
重复上述过程直至经训练得到的第一参数满足预设收敛条件。
2.根据权利要求1所述的生成网络的训练方法,其特征在于,所述映射网络用于将所述真实样本和生成样本映射到特征表示空间内。
3.根据权利要求2所述的生成网络的训练方法,其特征在于,所述映射网络的映射包括:降维操作。
4.根据权利要求2所述的生成网络的训练方法,其特征在于,所述特征距离为经所述映射网络映射到特征表示空间内的真实样本的特征和生成样本的特征之间的最大均值差异。
5.根据权利要求1所述的生成网络的训练方法,其特征在于,所述利用所述真实样本和所述生成样本对映射网络的第二参数进行训练包括:
计算经所述映射网络得到的真实样本和生成样本之间的特征距离;
基于所述特征距离确定所述第二参数的偏移量,以用于更新所述第二参数;
利用更新后的第二参数重复上述步骤,直至所得到的所述特征距离符合所述最大化条件。
6.根据权利要求5所述的生成网络的训练方法,其特征在于,所述第二参数的偏移量包括:基于所述特征距离在第二参数上的梯度而得到的偏移量。
7.根据权利要求1所述的生成网络的训练方法,其特征在于,所述利用训练后的映射网络对所述生成网络的第一参数进行训练的步骤包括:
在训练后的第二参数的条件下,基于所述真实样本和每次更新第一参数后生成网络输出的生成样本之间的特征距离而确定的所述第一参数的偏移量,更新所述生成网络的第一参数,直至所述真实样本和训练后的生成网络输出的生成样本之间的特征距离符合所述最小化条件。
8.根据根据权利要求7所述的生成网络的训练方法,其特征在于,所述第一参数的偏移量包括:每次更新第一参数后生成网络输出的生成样本和所述真实样本之间的特征距离在第一参数上的梯度而得到的偏移量。
9.根据权利要求1所述的生成网络的训练方法,其特征在于,所述训练方法还包括获取一随机变量的步骤,所述生成样本即为所述生成网络基于所述随机变量生成的。
10.根据权利要求9所述的生成网络的训练方法,其特征在于,所述随机变量遵循于一正态分布或一均匀分布。
11.根据权利要求1所述的生成网络的训练方法,其特征在于,所述真实样本为参考图像,所述生成样本为生成网络输出的生成图像。
12.一种生成网络的训练系统,其特征在于,包括:
映射网络模块,具有第二参数,用于获取真实样本和生成样本,并基于第二参数输出映射后的真实样本和生成样本;
生成网络模块,具有第一参数,用于基于第一参数输出生成样本;
训练模块,用于基于第一参数对所述映射网络模块的第二参数训练,直至所训练的映射网络模块使得所述真实样本和使用第一参数的生成网络模块输出的生成样本之间的特征距离符合最大化条件,以及用于基于训练后的映射网络模块对所述生成网络模块的第一参数训练,直至训练后的第二参数的条件下的所述真实样本和训练后的生成网络模块输出的生成样本之间的特征距离符合最小化条件;
所述训练模块还用于重复上述训练过程,直至经训练的第一参数满足预设收敛条件。
13.根据权利要求12所述的生成网络的训练系统,其特征在于,所述映射网络模块包括:特征表示单元,用于将所述真实样本和生成样本映射到特征表示空间内。
14.根据权利要求13所述的生成网络的训练系统,其特征在于,所述特征表示单元将所述真实样本和生成样本映射到特征表示空间内包括:对真实样本的特征和生成样本的特征进行降维处理。
15.根据权利要求12所述的生成网络的训练系统,其特征在于,所述特征距离为经所述映射网络模块映射到特征表示空间内的真实样本的特征和生成样本的特征之间的最大均值差异。
16.根据权利要求12所述的生成网络的训练系统,其特征在于,所述训练模块还包括用于计算所述特征距离的计算单元。
17.根据权利要求16所述的生成网络的训练系统,其特征在于,所述训练模块还包括:
第二更新单元,用于基于所述真实样本和所述生成样本之间的特征距离而确定的所述第二参数的偏移量,更新所述第二参数;
第二判别单元,用于在判断利用更新后的第二参数所得到的特征距离符合最大化条件时停止第二更新单元的更新。
18.根据权利要求17所述的生成网络的训练系统,其特征在于,所述第二参数的偏移量包括:基于所述真实样本和所述生成样本之间的特征距离在第二参数上的梯度变化而得到的偏移量。
19.根据权利要求16所述的生成网络的训练系统,其特征在于,所述训练模块还包括:
第一更新单元,用于在训练后的第二参数的条件下,基于所述真实样本和每次更新第一参数后生成网络模块输出的生成样本之间的特征距离而确定的所述第一参数的偏移量,更新所述生成网络模块的第一参数;
第一判别单元,用于在判断基于每次更新后的第一参数而输出的生成样本与所述真实样本之间的特征距离符合最小化条件时停止第一更新单元的更新。
20.根据权利要求19所述的生成网络的训练系统,其特征在于,所述第一参数的偏移量包括:基于所述真实样本和每次更新第一参数后生成网络输出的生成样本之间的特征距离在第一参数上的梯度变化而得到的偏移量。
21.根据权利要求12所述的生成网络的训练系统,其特征在于,所述生成网络模块还用于获取一随机变量,所述生成样本即为所述生成网络模块基于所述随机变量生成的。
22.根据权利要求21所述的生成网络的训练系统,其特征在于,所述随机变量遵循于一正态分布或一均匀分布。
23.根据权利要求12所述的生成网络的训练系统,其特征在于,所述真实样本为参考图像,所述生成样本为所述生成网络模块生成的图像。
24.一种图像生成方法,其特征在于,包括以下步骤:
获取一原始输入图像;
利用上述权利要求1至11任一项所述的生成网络训练方法获得的生成网络对所述原始输入图像进行处理以输出一生成图像。
25.一种图像生成设备,其特征在于,包括:
图像采集装置,用于获取待处理的原始输入图像;
存储装置,用于存储计算机程序;
处理装置,通信连接所述图像采集装置及存储装置,用于运行所述计算机程序来执行如权利要求24所述的图像生成方法。
26.一种图像生成客户端,其特征在于,装载于一智能设备;所述客户端包括:
输入模块,用于接收到生成指令时,调用所述智能设备的图像采集装置获取待处理的原始输入图像;
处理模块,用于执行如权利要求24所述的图像生成方法,以获得生成图像。
27.根据权利要求26所述的图像生成客户端,其特征在于,所述客户端还包括一显示模块,用于显示所述生成图像。
28.一种计算机可读存储介质,其特征在于,存储有至少一计算机程序,所述计算机程序被用于执行实现权利要求1至11任一所述的生成网络的训练方法或执行实现权利要求24所述的图像生成的方法。
CN202010152216.5A 2020-03-06 2020-03-06 生成网络的训练方法和系统、以及图像生成方法及设备 Active CN111062468B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010152216.5A CN111062468B (zh) 2020-03-06 2020-03-06 生成网络的训练方法和系统、以及图像生成方法及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010152216.5A CN111062468B (zh) 2020-03-06 2020-03-06 生成网络的训练方法和系统、以及图像生成方法及设备

Publications (2)

Publication Number Publication Date
CN111062468A true CN111062468A (zh) 2020-04-24
CN111062468B CN111062468B (zh) 2023-06-20

Family

ID=70307890

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010152216.5A Active CN111062468B (zh) 2020-03-06 2020-03-06 生成网络的训练方法和系统、以及图像生成方法及设备

Country Status (1)

Country Link
CN (1) CN111062468B (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111652064A (zh) * 2020-04-30 2020-09-11 平安科技(深圳)有限公司 人脸图像生成方法、电子装置及可读存储介质
CN111652064B (zh) * 2020-04-30 2024-06-07 平安科技(深圳)有限公司 人脸图像生成方法、电子装置及可读存储介质

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110443293A (zh) * 2019-07-25 2019-11-12 天津大学 基于双判别生成对抗网络文本重构的零样本图像分类方法
CN110827201A (zh) * 2019-11-05 2020-02-21 广东三维家信息科技有限公司 用于高动态范围图像超分辨率重建的生成式对抗网络训练方法及装置
CN110853012A (zh) * 2019-11-11 2020-02-28 苏州锐一仪器科技有限公司 获得心脏参数的方法、装置及计算机存储介质

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110443293A (zh) * 2019-07-25 2019-11-12 天津大学 基于双判别生成对抗网络文本重构的零样本图像分类方法
CN110827201A (zh) * 2019-11-05 2020-02-21 广东三维家信息科技有限公司 用于高动态范围图像超分辨率重建的生成式对抗网络训练方法及装置
CN110853012A (zh) * 2019-11-11 2020-02-28 苏州锐一仪器科技有限公司 获得心脏参数的方法、装置及计算机存储介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111652064A (zh) * 2020-04-30 2020-09-11 平安科技(深圳)有限公司 人脸图像生成方法、电子装置及可读存储介质
CN111652064B (zh) * 2020-04-30 2024-06-07 平安科技(深圳)有限公司 人脸图像生成方法、电子装置及可读存储介质

Also Published As

Publication number Publication date
CN111062468B (zh) 2023-06-20

Similar Documents

Publication Publication Date Title
CN107636690B (zh) 基于卷积神经网络的全参考图像质量评估
CN110991652A (zh) 神经网络模型训练方法、装置及电子设备
CN108197652B (zh) 用于生成信息的方法和装置
CN111860573A (zh) 模型训练方法、图像类别检测方法、装置和电子设备
CN112634170B (zh) 一种模糊图像修正的方法、装置、计算机设备及存储介质
CN111950723A (zh) 神经网络模型训练方法、图像处理方法、装置及终端设备
WO2023040510A1 (zh) 图像异常检测模型训练方法、图像异常检测方法和装置
CN111695421B (zh) 图像识别方法、装置及电子设备
CN111738243A (zh) 人脸图像的选择方法、装置、设备及存储介质
CN111242217A (zh) 图像识别模型的训练方法、装置、电子设备及存储介质
CN111667420B (zh) 图像处理方法及装置
CN109214501B (zh) 用于识别信息的方法和装置
CN114925748B (zh) 模型训练及模态信息的预测方法、相关装置、设备、介质
CN107590460A (zh) 人脸分类方法、装置及智能终端
CN112613543A (zh) 增强策略验证方法、装置、电子设备及存储介质
CN115205736A (zh) 视频数据的识别方法和装置、电子设备和存储介质
CN111382791A (zh) 深度学习任务处理方法、图像识别任务处理方法和装置
CN111242176A (zh) 计算机视觉任务的处理方法、装置及电子系统
CN113516697A (zh) 图像配准的方法、装置、电子设备及计算机可读存储介质
CN113902944A (zh) 模型的训练及场景识别方法、装置、设备及介质
CN111062468A (zh) 生成网络的训练方法和系统、以及图像生成方法及设备
TWI818496B (zh) 指紋識別方法、指紋模組及電子設備
CN110008907B (zh) 一种年龄的估计方法、装置、电子设备和计算机可读介质
CN114241044A (zh) 回环检测方法、装置、电子设备和计算机可读介质
CN112580689A (zh) 神经网络模型的训练方法、应用方法、装置和电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant
TR01 Transfer of patent right
TR01 Transfer of patent right

Effective date of registration: 20231116

Address after: Building C21, 6th Floor, Zidong International Creative Park, No. 2 Zidong Road, Maqun Street, Nanjing City, Jiangsu Province, 210046

Patentee after: TULING ARTIFICIAL INTELLIGENCE INSTITUTE (NANJING) Co.,Ltd.

Patentee after: TSINGHUA University

Address before: 210046 601 room, No. 6, Qi Min Road, Xianlin street, Qixia District, Nanjing, Jiangsu, China. 6

Patentee before: TULING ARTIFICIAL INTELLIGENCE INSTITUTE (NANJING) Co.,Ltd.