CN115564024B - 生成网络的特征蒸馏方法、装置、电子设备及存储介质 - Google Patents

生成网络的特征蒸馏方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN115564024B
CN115564024B CN202211242759.1A CN202211242759A CN115564024B CN 115564024 B CN115564024 B CN 115564024B CN 202211242759 A CN202211242759 A CN 202211242759A CN 115564024 B CN115564024 B CN 115564024B
Authority
CN
China
Prior art keywords
image
preset
loss
network
distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211242759.1A
Other languages
English (en)
Other versions
CN115564024A (zh
Inventor
季向阳
杨宇
程笑天
刘畅
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Original Assignee
Tsinghua University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University filed Critical Tsinghua University
Priority to CN202211242759.1A priority Critical patent/CN115564024B/zh
Publication of CN115564024A publication Critical patent/CN115564024A/zh
Application granted granted Critical
Publication of CN115564024B publication Critical patent/CN115564024B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请涉及计算机视觉与深度学习技术领域,特别涉及一种生成网络的特征蒸馏方法、装置、电子设备及存储介质,其中,包括:获取目标生成网络中的多个特征图;将其输入至预设挤压模块,并从中挤压出满足预设变换不变性的预设图像特征;从预设数据增广中随机抽样出图像变换算子,并利用图像变换算子对预设图像特征进行特征蒸馏,得到目标生成网络在合成图像领域的图像;并将其输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得目标生成网络的蒸馏表征扩张至真实图像领域,实现目标生成网络的特征蒸馏。由此,解决了无法从GAN的生成器中蒸馏出价值的表征信息,无法充分利用或将表征迁移至下游任务,降低了表征提取网络的性能等问题。

Description

生成网络的特征蒸馏方法、装置、电子设备及存储介质
技术领域
本申请涉及数字图像处理,模式识别及计算机视觉与深度学习技术领域,特别涉及一种生成网络的特征蒸馏方法、装置、电子设备及存储介质。
背景技术
归功于大型数据集和网络架构设计的最新进展,GAN(Generative AdversarialNetworks,生成式对抗网络)不断取得令人印象深刻的图像合成结果,GAN不仅能合成逼真的图像,还能对其中的内容或风格进行操控,这些特性促使大量的工作利用预训练GAN来完成各种计算机视觉任务,包括部分分割,三维重建,图像对齐等,并展示GAN在人工标注不足时的优势,GAN可以产生细粒度的、解耦的以及可解释的表征,因而带来数据效率与泛化性能上的优势。
相关技术中,基于GAN的表征学习的工作侧重于利用判别器网络的特征或是通过训练一个编码器网络将图像映射回隐空间并利用其中学习到的特征,GAN的生成器所具有的可控合成特性表明其拥有信息丰富的、解耦的以及可解释的图像表征,然而利用或将表征迁移至下游任务仍然没有得到充分的探索,与判别器网络或编码器网络不同,生成器网络输入为隐变量,不能接受图像作为输入并输出表征或预测结果,难以直接迁移到其他下游任务。
发明内容
本申请提供一种生成网络的特征蒸馏方法、装置、电子设备及存储介质,以解决相关技术中无法从GAN的生成器中蒸馏出价值的表征信息,无法充分利用或将表征迁移至下游任务,降低了表征提取网络的性能等问题。
本申请第一方面实施例提供一种生成网络的特征蒸馏方法,包括以下步骤:获取目标生成网络中的多个特征图;将所述多个特征图输入至预设挤压模块,从所述多个特征图中挤压出满足预设变换不变性的预设图像特征;从预设数据增广中随机抽样出图像变换算子,并利用所述图像变换算子对所述预设图像特征进行特征蒸馏,得到所述目标生成网络在合成图像领域的图像;将所述目标生成网络在所述合成图像领域的图像输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得所述目标生成网络的蒸馏表征扩张至真实图像领域,实现所述目标生成网络的特征蒸馏。
可选地,所述利用所述图像变换算子对所述预设图像特征进行特征蒸馏,得到所述目标生成网络在合成图像领域的图像蒸馏表征,还包括:利用所述图像变换算子对所述预设图像特征进行图像保持语义不变的特征蒸馏,并获取蒸馏过程中的蒸馏损失;在所述蒸馏损失中加入正则化项,对每个预设图像特征进行正则化,使得所述每个预设图像特征在每个维度上的变化程度大于预设变化程度,且去除所述每个维度之间的相关性,得到目标生成网络在合成图像领域的图像蒸馏表征。
可选地,所述预设学生网络的训练,包括:获取预设学生网络的训练数据;其中,所述训练数据包括合成数据与真实数据;利用预设的将生成器中的表征挤压到学生网络中的损失函数计算所述合成数据的损失;对于所述真实数据中每个训练图像利用随机数据增广进行多次变换,得到第一视图和第二视图;将所述第一视图和所述第二视图输入待训练的预设学生网络,输出所述第一视图的第一表征和所述第二视图的第二表征,根据所述第一表征和所述第二表征计算所述真实数据的损失;根据所述合成数据的损失和所述真实数据的损失计算每步训练迭代的总体损失,并通过训练迭代得到所述预设学生网络。
可选地,所述根据所述第一表征和所述第二表征计算得到训练损失值,包括:根据所述第一表征和所述第二表征计算得到真实数据的损失值;根据所述训练数据中合成数据的损失值与所述真实数据的损失值计算得到所述训练损失值。
本申请第二方面实施例提供一种生成网络的特征蒸馏装置,包括:获取模块,用于获取目标生成网络中的多个特征图;挤压模块,用于将所述多个特征图输入至预设挤压模块,从所述多个特征图中挤压出满足预设变换不变性的预设图像特征;处理模块,用于从预设数据增广中随机抽样出图像变换算子,并利用所述图像变换算子对所述预设图像特征进行特征蒸馏,得到所述目标生成网络在合成图像领域的图像;生成模块,用于将所述目标生成网络在所述合成图像领域的图像输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得所述目标生成网络的蒸馏表征扩张至真实图像领域,实现所述目标生成网络的特征蒸馏。
可选地,所述处理模块用于:利用所述图像变换算子对所述预设图像特征进行图像保持语义不变的特征蒸馏,并获取蒸馏过程中的蒸馏损失;在所述蒸馏损失中加入正则化项,对每个预设图像特征进行正则化,使得所述每个预设图像特征在每个维度上的变化程度大于预设变化程度,且去除所述每个维度之间的相关性,得到目标生成网络在合成图像领域的图像蒸馏表征。
可选地,所述生成模块用于:获取预设学生网络的训练数据,其中,所述训练数据包括合成数据与真实数据;利用预设的将生成器中的表征挤压到学生网络中的损失函数计算所述合成数据的损失;对于所述真实数据中每个训练图像利用随机数据增广进行多次变换,得到第一视图和第二视图;将所述第一视图和所述第二视图输入待训练的预设学生网络,输出所述第一视图的第一表征和所述第二视图的第二表征,根据所述第一表征和所述第二表征计算真实数据的损失;根据所述合成数据的损失和所述真实数据的损失计算每步训练迭代的总体损失,并通过训练迭代得到所述预设学生网络。
可选地,所述生成模块进一步用于:根据第一表征和第二表征计算得到真实数据的损失值;根据训练数据中合成数据的损失值与所述真实数据的损失值计算得到所述训练损失值。
本申请第三方面实施例提供一种电子设备,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述程序,以实现如上述实施例所述的生成网络的特征蒸馏方法。
本申请第四方面实施例提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行,以用于实现如上述实施例所述的生成网络的特征蒸馏方法。
由此,本申请至少具有如下有益效果:
本申请实施例挤压与扩张的方式将GAN生成器的表征知识蒸馏出来,将生成器的特征挤压成对通过网络进行语义保全转换不变的表示,然后再将其提炼到学生网络中,使用真实的训练数据将合成领域的蒸馏表征跨越到真实领域,以弥补GANs的模式崩溃,并提高学生网络在真实领域的性能,并获取一个较高性能的表征提取网络,对图像提取迁移性强的表征,能够将生成网络中有价值的表征信息蒸馏到表征网络中,并在下游任务中表现优异。
本申请附加的方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本申请的实践了解到。
附图说明
本申请上述的和/或附加的方面和优点从下面结合附图对实施例的描述中将变得明显和容易理解,其中:
图1为根据本申请实施例的生成网络的特征蒸馏方法的流程图;
图2为根据本申请实施例的GAN的生成器中挤压与扩张表征示意图;
图3为根据本申请实施例的生成网络的特征蒸馏装置的方框示意图;
图4为根据本申请实施例的电子设备的结构示意图。
具体实施方式
下面详细描述本申请的实施例,所述实施例的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施例是示例性的,旨在用于解释本申请,而不能理解为对本申请的限制。
通常GAN生成器将低分辨率(如4×4)的特征图转化为高分辨率(如256×256),并从最终的特征图中进一步合成图像或多尺度特征图,形式上,令G=g(L)·g(L-1)···g(1)表示L个模块串联的生成器,给定一个从先验分布中采样的隐变量w~P(w),将每个生成器模块输出的特征图的平均池化向量拼接,得到生成器表征:
其中,/>
然而,由于原始的GAN并不提供一个准确的由图像到隐空间的逆向模型,对于任何给定的图像,提取生成器特征仍然是不方便的。本申请实施例则可以从GAN生成器中蒸馏出有价值的特征。
下面参考附图描述本申请实施例的生成网络的特征蒸馏方法、装置、电子设备及存储介质。具体而言,图1为本申请实施例所提供的一种生成网络的特征蒸馏方法的流程示意图。
如图1所示,该生成网络的特征蒸馏方法包括以下步骤:
在步骤S101中,获取目标生成网络中的多个特征图。
可以理解的是,本申请实施例通过获取目标生成网络中特征图,为下一步对其挤压做准备。
在步骤S102中,将多个特征图输入至预设挤压模块,从多个特征图中挤压出满足预设变换不变性的预设图像特征。
其中,预设挤压模块可以是用户设置的模块,例如:引入一个挤压模块Tφ,将有信息的表征从生成器表示中挤出,对每个合成块的特征图进行平均池化,在此不做具体限定。
其中,预设变换不变性可以是挤压后的特征图与挤压前的特征图仅发生一定变换但性质不会发生改变,在此不做具体限定。
其中,预设图像特征可以是通过GAN生成器将低分辨率(如4×4)的特征图转化为高分辨率(如256×256),并从最终的特征图中进一步合成图像或多尺度特征图,在此不做具体限定。
可以理解的是,本申请实施例将获取到的多个特征图输入至挤压模块,并从中挤压出满足条件的图像特征,便于后续对图像特征进行特征蒸馏,降低工作量。
在步骤S103中,从预设数据增广中随机抽样出图像变换算子,并利用图像变换算子对预设图像特征进行特征蒸馏,得到目标生成网络在合成图像领域的图像。
其中,预设数据增广可以用于增加训练数据集,让数据集尽可能的多样化,使得训练的模型具有更强的泛化能力,在此不做具体限定。
其中,图像变换算子可以从几何或光度领域对图像特征进行改变,反映了图像上具有的相关特性,在此不做具体限定。
可以理解的是,本申请实施例从表征数据中随机抽样,并利用图像变换算子对图像特征进行特征蒸馏,得到合成图像领域的图像,对合成图像引入数据增广,能够提高表征提取网络的性能,并且使其对图像提取的迁移性增强。
在本申请实施例中,利用图像变换算子对预设图像特征进行特征蒸馏,得到目标生成网络在合成图像领域的图像蒸馏表征,还包括:利用图像变换算子对预设图像特征进行图像保持语义不变的特征蒸馏,并获取蒸馏过程中的蒸馏损失;在蒸馏损失中加入正则化项,对每个预设图像特征进行正则化,使得每个预设图像特征在每个维度上的变化程度大于预设变化程度,且去除每个维度之间的相关性,得到目标生成网络在合成图像领域的图像蒸馏表征。
其中,正则化的目的是限制模型的参数过多或者过大,避免模型太过复杂,可以在一定程度上抑制过拟合,让模型获得抗噪声的能力,在此不做具体限定。
其中,预设变化程度可以是用户事先设置的,可根据具体情况进行调整,在此不做具体限定。
可以理解的是,本申请实施例利用图像变换算子能够对图像特征的语义不变表征蒸馏出来,并获取蒸馏过程中的损失,并对蒸馏损失加入正则化项防止退化解的产生,同时采用方差-协方差来对表征进行正则化,使其在每个维度上都有明显的变化以及维度间去相关,来约束表征网络的输出,能够将生成网络中有价值的表征信息蒸馏到表征网络中,获取一个较高性能的表征提取网络,且对图像提取迁移性强的表征。
在步骤S104中,将目标生成网络在合成图像领域的图像输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得目标生成网络的蒸馏表征扩张至真实图像领域,实现目标生成网络的特征蒸馏。
其中,预设学生网络可以是用户事先设置的网络,用于将真实数据引入到学生网络做训练使用,在此不做具体限定。
其中,合成图像领域可以是低质量的合成的图像领域,在此不做具体限定。
可以理解的是,本申请实施例通过将在合成图像领域的图像输入至学生网络中,得到真实图像领域的图像蒸馏表征,可以有效防止GAN的模式坍塌以及缓解合成域和真实域之间的问题,并在下游从真实图像中提取特征时中表现优异。
在本申请实施例中,预设学生网络的训练,包括:获取预设学生网络的训练数据;其中,训练数据包括合成数据与真实数据;利用预设的将生成器中的表征挤压到学生网络中的损失函数计算合成数据的损失;对于真实数据中每个训练图像利用随机数据增广进行多次变换,得到第一视图和第二视图;将第一视图和第二视图输入待训练的预设学生网络,输出第一视图的第一表征和第二视图的第二表征,根据第一表征和第二表征计算真实数据的损失;根据合成数据的损失和真实数据的损失计算每步训练迭代的总体损失,并通过训练迭代得到预设学生网络。
其中,第一视图和第二视图可以是用户事先设置的,例如:将两个视图设为和/>在此不做具体限定。
其中,第一表征和第二表征可以是用户事先设置的,例如:将两个表征设为Zr和Zr′,在此不做具体限定。
可以理解的是,本申请实施例获取真实的训练数据,并对其中每个训练图像利用随机数据增广进行多次变化,将变换后得到的视图输入至待训练的学生网络中,得到相应表征和训练损失值,计算总体损失并通过训练迭代得到学生网络,使其在下游任务中从真实图像中提取特征时表现优异,并提高学生网络在真实领域的性能。
在本申请实施例中,根据第一表征和第二表征计算得到训练损失值,包括:根据第一表征和第二表征计算得到真实数据的损失值;根据训练数据中合成数据的损失值与真实数据的损失值计算得到训练损失值。
可以理解的是,本申请实施例通过第一表征和第二表征计算真实数据的损失值,并根据训练数据中合成数据的损失值与真实数据的损失值计算得到训练的总损失值,使得到的数据更加准确,使用户可以更直观的看到相应的损失做出相应的操作。
根据本申请实施例提出的生成网络的特征蒸馏方法,通过对获取的多个图像信息特征进行挤压,使其挤压成有价值的图像特征然后随机抽样,并利用图像变换算子对有价值的图像特征进行蒸馏,然后再将其提炼到学生网络中,使用真实的训练数据将合成领域的蒸馏表征跨越到真实领域,以弥补GAN的模式崩溃,并提高学生网络在真实领域的性能,可以提取出有价值的表征信息,并获取一个较高性能的表征提取网络,对图像提取迁移性强的表征。由此,解决了相关技术中无法从GAN的生成器中蒸馏出价值的表征信息,无法充分利用或将表征迁移至下游任务,降低了表征提取网络的性能等问题。
下面将结合图2对生成网络的特征蒸馏方法进行详细阐述,具体如下:
如图2所示,在左半图中,预训练的生成器G和挤压模块Tθ构成教师网络,产生挤压表征,并被蒸馏至学生网络Sθ(挤压部分),于此同时学生网络也在真实数据上进行训练(扩张部分)。右半图展示了生成器与挤压模块的结构,以合成32×32分辨率的图像的StyleGAN2生成器为例说明,挤压模块对每个合成块的特征图进行平均池化(用μ表示),并用线性层加MLP(Multilayer Perceptron,多层感知器)对其进行转换。
(1)为了缓解生成器表征可能包含太多与下游任务无关的信息,本申请实施例引入了一个挤压模块Tθ(如图2所示),将有信息的表征从生成器表示中挤出;此外,本申请实施例在将生成的图像送入学生网络之前,用保持语义不变的图像变换a(例如颜色扰动或裁剪)对其进行变换,图像变换公式为:
其中,图像变换a是从A中随机抽样的,换句话说,本申请实施例试图从生成器中蒸馏出对数据增广A不变的紧凑表征。
然而,与自监督学习中的孪生网络类似,存在一个对公式(2)的平凡解:挤压模块和学生网络退化为对任何输入都输出常数。
因此,本申请实施例在蒸馏损失中加入了正则化项防止退化解的产生,特别地,本申请实施例采用方差-协方差来对表征进行正则化,使其在每个维度上都有明显的变化以及维度间去相关。形式上,在一个由N个样本组成的小批数据中,挤压的生成器表征与学生网络产生的表征分别表示为
其中wi~P(w)和ai~A分别表示隐变量的随机采样与数据增广算子。
方差损失鼓励每个表示维度的标准差大于1,具体公式如下:
其中zj表示z中的第j个维度。协方差损失鼓励任意一对维度不相关,
综上,将生成器中的表征挤压到学生网络中的损失函数可以概括为
其中λ、μ、ν分别用于调节各项损失的权重。
(2)为了缓解真实图像和合成图像之间存在一定的领域差异,使其在下游任务中从真实图像中提取特征时很可能表现不佳,导致这种原因是由于合成域和真实域之间的问题,其中,合成的图像可能是低质量的,这方面在近期的GAN建模中得到了很大的改善;更重要的是,GAN的一个顽疾是模式坍塌,即合成数据只能覆盖真实数据分布的部分模式。
而本申请实施例为了缓解模式坍塌的危害,将真实数据引入到学生网络的训练数据中,特别地,在每步训练迭代中,一小批训练数据由合成数据与真实数据组成,对于合成数据,采用前述的挤压损失;对于真实数据,本申请实施例采用原始的方差-不变-协方差正则方法来计算损失。
具体来说,给定一小批真实数据每个图像/>用随机数据增广进行两次变换,得到两个视图/>和/>其中ai,ai′~A,与公式(3)类似,两个视图输入Sθ分别得到相应的表征Zr和Zr′。然后,真实数据的损失被计算为
其中表示为通过衡量真实图像上两视图表表征距离的自我蒸馏损失。总体损失是通过简单地组合生成数据损失与真实数据损失来计算的,即Ltoral=αLsqueeze+(1-α)Lspan,其中α=0.5表示合成数据在小批训练样本中的比例。
综上,提出了一个由MLP实现的挤压模块,并且对合成图像引入数据增广,使得挤压模块能将生成器网络的语义不变表征挤压出来,由于直接对挤压模块与表征网络联合优化会导致平凡解,如输出恒为零向量的表征,本申请实施例采用了方差-协方差正则化方法来约束表征网络的输出;若仅将合成图像的生成器表征蒸馏出来,合成图像和真实图像之间的存在的领域差异将导致表征网络难以适应真实图像的表征提取;而本申请实施例通过将额外的真实图像引入训练过程,使得生成网络的表征可扩张至真实图像域,能够将生成网络中有价值的表征信息蒸馏到表征网络中,并在下游任务中表现优异。
其次参照附图描述根据本申请实施例提出的生成网络的特征蒸馏装置。
图3是本申请实施例的生成网络的特征蒸馏装置的方框示意图。
如图3所示,该生成网络的特征蒸馏装置10包括:获取模块100、挤压模块200、处理模块300和生成模块400。
其中,获取模块100用于获取目标生成网络中的多个特征图;挤压模块200用于将多个特征图输入至预设挤压模块,从多个特征图中挤压出满足预设变换不变性的预设图像特征;处理模块300用于从预设数据增广中随机抽样出图像变换算子,并利用图像变换算子对预设图像特征进行特征蒸馏,得到目标生成网络在合成图像领域的图像;生成模块400用于将目标生成网络在合成图像领域的图像输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得目标生成网络的蒸馏表征扩张至真实图像领域,实现目标生成网络的特征蒸馏。
在本申请实施例中,处理模块300用于:利用图像变换算子对预设图像特征进行图像保持语义不变的特征蒸馏,并获取蒸馏过程中的蒸馏损失;在蒸馏损失中加入正则化项,对每个预设图像特征进行正则化,使得每个预设图像特征在每个维度上的变化程度大于预设变化程度,且去除每个维度之间的相关性,得到目标生成网络在合成图像领域的图像蒸馏表征。
在本申请实施例中,生成模块400用于:获取预设学生网络的训练数据;其中,训练数据包括合成数据与真实数据;利用预设的将生成器中的表征挤压到学生网络中的损失函数计算合成数据的损失;对于真实数据中每个训练图像利用随机数据增广进行多次变换,得到第一视图和第二视图;将第一视图和第二视图输入待训练的预设学生网络,输出第一视图的第一表征和第二视图的第二表征,根据第一表征和第二表征计算真实数据的损失;根据合成数据的损失和真实数据的损失计算每步训练迭代的总体损失,并通过训练迭代得到预设学生网络。
在本申请实施例中,生成模块400进一步用于:根据第一表征和第二表征计算得到真实数据的损失值;根据训练数据中合成数据的损失值与真实数据的损失值计算得到训练损失值。
需要说明的是,前述对生成网络的特征蒸馏方法实施例的解释说明也适用于该实施例的生成网络的特征蒸馏装置,此处不再赘述。
根据本申请实施例提出的生成网络的特征蒸馏装置,通过对获取的多个图像信息特征进行挤压,使其挤压成有价值的图像特征然后随机抽样,并利用图像变换算子对有价值的图像特征进行蒸馏,然后再将其提炼到学生网络中,使用真实的训练数据将合成领域的蒸馏表征跨越到真实领域,以弥补GANs的模式崩溃,并提高学生网络在真实领域的性能,可以提取出有价值的表征信息,并获取一个较高性能的表征提取网络,对图像提取迁移性强的表征。由此,解决了相关技术中无法从GAN的生成器中蒸馏出价值的表征信息,无法充分利用或将表征迁移至下游任务,降低了表征提取网络的性能等问题。
图4为本申请实施例提供的电子设备的结构示意图。该电子设备可以包括:
存储器401、处理器402及存储在存储器401上并可在处理器402上运行的计算机程序。
处理器402执行程序时实现上述实施例中提供的生成网络的特征蒸馏方法。
进一步地,电子设备还包括:
通信接口403,用于存储器401和处理器402之间的通信。
存储器401,用于存放可在处理器402上运行的计算机程序。
存储器401可能包含高速RAM(Random Access Memory,随机存取存储器)存储器,也可能还包括非易失性存储器,例如至少一个磁盘存储器。
如果存储器401、处理器402和通信接口403独立实现,则通信接口403、存储器401和处理器402可以通过总线相互连接并完成相互间的通信。总线可以是ISA(IndustryStandard Architecture,工业标准体系结构)总线、PCI(Peripheral Component,外部设备互连)总线或EISA(Extended Industry Standard Architecture,扩展工业标准体系结构)总线等。总线可以分为地址总线、数据总线、控制总线等。为便于表示,图4中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
可选的,在具体实现上,如果存储器401、处理器402及通信接口403,集成在一块芯片上实现,则存储器401、处理器402及通信接口403可以通过内部接口完成相互间的通信。
处理器402可能是一个CPU(Central Processing Unit,中央处理器),或者是ASIC(Application Specific Integrated Circuit,特定集成电路),或者是被配置成实施本申请实施例的一个或多个集成电路。
本申请实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上的生成网络的特征蒸馏方法。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本申请的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不是必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或N个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。在本申请的描述中,“N个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。
流程图中或在此以其他方式描述的任何过程或方法描述可以被理解为,表示包括一个或更N个用于实现定制逻辑功能或过程的步骤的可执行指令的代码的模块、片段或部分,并且本申请的优选实施方式的范围包括另外的实现,其中可以不按所示出或讨论的顺序,包括根据所涉及的功能按基本同时的方式或按相反的顺序,来执行功能,这应被本申请的实施例所属技术领域的技术人员所理解。
应当理解,本申请的各部分可以用硬件、软件、固件或它们的组合来实现。在上述实施方式中,N个步骤或方法可以用存储在存储器中且由合适的指令执行系统执行的软件或固件来实现。如,如果用硬件来实现和在另一实施方式中一样,可用本领域公知的下列技术中的任一项或他们的组合来实现:具有用于对数据信号实现逻辑功能的逻辑门电路的离散逻辑电路,具有合适的组合逻辑门电路的专用集成电路,可编程门阵列,现场可编程门阵列等。
本技术领域的普通技术人员可以理解实现上述实施例方法携带的全部或部分步骤是可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,该程序在执行时,包括方法实施例的步骤之一或其组合。尽管上面已经示出和描述了本申请的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本申请的限制,本领域的普通技术人员在本申请的范围内可以对上述实施例进行变化、修改、替换和变型。

Claims (6)

1.一种生成网络的特征蒸馏方法,其特征在于,包括以下步骤:
获取目标生成网络中的多个特征图;
将所述多个特征图输入至预设挤压模块,从所述多个特征图中挤压出满足预设变换不变性的预设图像特征,其中,所述预设挤压模块用于对每个特征图进行平均池化,并用线性层和多层感知器对池化后的特征图进行转换;
从预设数据增广中随机抽样出图像变换算子,并利用所述图像变换算子对所述预设图像特征进行图像保持语义不变的特征蒸馏,并获取蒸馏过程中的蒸馏损失,在所述蒸馏损失中加入正则化项,对每个预设图像特征进行正则化,使得所述每个预设图像特征在每个维度上的变化程度大于预设变化程度,且去除所述每个维度之间的相关性,得到目标生成网络在合成图像领域的图像蒸馏表征,从而生成所述目标生成网络在合成图像领域的图像,其中,蒸馏损失是蒸馏学习过程中从目标生成网络到预设学生网络产生的图像特征的损耗,图像变换算子是从几何和/或光度领域对图像特征进行改变,反映了图像上具有的相关特性;
将所述目标生成网络在所述合成图像领域的图像输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得所述目标生成网络的图像蒸馏表征扩张至真实图像领域,实现所述目标生成网络的特征蒸馏。
2.根据权利要求1所述的方法,其特征在于,预设学生网络的训练,包括:
获取预设学生网络的训练数据,其中,所述训练数据包括合成数据与真实数据;
利用预设的将生成器中的表征挤压到学生网络中的损失函数计算所述合成数据的损失,其中,损失函数为:
其中,Lsqueeze表示合成数据的损失,Zg表示挤压的生成器表征,λ、μ、ν分别用于调节各项损失的权重,
其中,Tφ表示挤压模块,Zj表示Z中的第j个维度,C(Z)表示表征的协方差矩阵;
对于所述真实数据中每个训练图像利用随机数据增广进行两次变换,得到第一视图和第二视图,将所述第一视图和所述第二视图输入待训练的预设学生网络,输出所述第一视图的第一表征和所述第二视图的第二表征,根据所述第一表征和所述第二表征计算所述真实数据的损失,其中,真实数据的损失计算公式为:其中,Lspan表示真实数据的损失,/>表示为通过衡量真实图像上两视图表表征距离的自我蒸馏损失,Zr和/>分别表示为第一表征和第二表征;
根据所述合成数据的损失和所述真实数据的损失计算每步训练迭代的总体损失,并通过训练迭代得到所述预设学生网络,其中,总体损失计算公式为:其中,α=0.5表示合成数据在小批训练样本中的比例,/>表示合成数据的损失,/>表示真实数据的损失。
3.一种生成网络的特征蒸馏装置,其特征在于,包括:
获取模块,用于获取目标生成网络中的多个特征图;
挤压模块,用于将所述多个特征图输入至预设挤压模块,从所述多个特征图中挤压出满足预设变换不变性的预设图像特征,其中,所述预设挤压模块用于对每个特征图进行平均池化,并用线性层和多层感知器对池化后的特征图进行转换;
处理模块,用于从预设数据增广中随机抽样出图像变换算子,并利用所述图像变换算子对所述预设图像特征进行图像保持语义不变的特征蒸馏,并获取蒸馏过程中的蒸馏损失,在所述蒸馏损失中加入正则化项,对每个预设图像特征进行正则化,使得所述每个预设图像特征在每个维度上的变化程度大于预设变化程度,且去除所述每个维度之间的相关性,得到目标生成网络在合成图像领域的图像蒸馏表征,从而生成所述目标生成网络在合成图像领域的图像,其中,蒸馏损失是蒸馏学习过程中从目标生成网络到预设学生网络产生的图像特征的损耗,图像变换算子是从几何和/或光度领域对图像特征进行改变,反映了图像上具有的相关特性;
生成模块,用于将所述目标生成网络在所述合成图像领域的图像输入至预设学生网络,同时输入真实图像进行自监督对比学习,使得所述目标生成网络的图像蒸馏表征扩张至真实图像领域,实现所述目标生成网络的特征蒸馏。
4.根据权利要求3所述的装置,其特征在于,所述生成模块用于:
获取预设学生网络的训练数据,其中,所述训练数据包括合成数据与真实数据;
利用预设的将生成器中的表征挤压到学生网络中的损失函数计算所述合成数据的损失,其中,损失函数为:
其中,Lsqueeze表示合成数据的损失,Zg表示挤压的生成器表征,λ、μ、v分别用于调节各项损失的权重,
其中,Tφ表示挤压模块,Zj表示Z中的第j个维度,其中,C(Z)表示表征的协方差矩阵;
对于所述真实数据中每个训练图像利用随机数据增广进行多次变换,得到第一视图和第二视图,将所述第一视图和所述第二视图输入待训练的预设学生网络,输出所述第一视图的第一表征和所述第二视图的第二表征,根据所述第一表征和所述第二表征计算所述真实数据的损失,其中,真实数据的损失计算公式为:Lspan=λL′RD+μ[Lvar(Zr)+Lvar(Z′r)]+ν[Lcov(Zr)+Lcov(Z′r)],其中,Lspan表示真实数据的损失,表示为通过衡量真实图像上两视图表表征距离的自我蒸馏损失,Zr和Z′r分别表示为第一表征和第二表征;
根据所述合成数据的损失和所述真实数据的损失计算每步训练迭代的总体损失,并通过训练迭代得到所述预设学生网络,其中,总体损失计算公式为:其中,α=0.5表示合成数据在小批训练样本中的比例,/>表示合成数据的损失,/>表示真实数据的损失。
5.一种电子设备,其特征在于,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述程序,以实现如权利要求1-2任一项所述的生成网络的特征蒸馏方法。
6.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行,以用于实现如权利要求1-2任一项所述的生成网络的特征蒸馏方法。
CN202211242759.1A 2022-10-11 2022-10-11 生成网络的特征蒸馏方法、装置、电子设备及存储介质 Active CN115564024B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211242759.1A CN115564024B (zh) 2022-10-11 2022-10-11 生成网络的特征蒸馏方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211242759.1A CN115564024B (zh) 2022-10-11 2022-10-11 生成网络的特征蒸馏方法、装置、电子设备及存储介质

Publications (2)

Publication Number Publication Date
CN115564024A CN115564024A (zh) 2023-01-03
CN115564024B true CN115564024B (zh) 2023-09-15

Family

ID=84745627

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211242759.1A Active CN115564024B (zh) 2022-10-11 2022-10-11 生成网络的特征蒸馏方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN115564024B (zh)

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111598216A (zh) * 2020-04-16 2020-08-28 北京百度网讯科技有限公司 学生网络模型的生成方法、装置、设备及存储介质
CN112465111A (zh) * 2020-11-17 2021-03-09 大连理工大学 一种基于知识蒸馏和对抗训练的三维体素图像分割方法
CN113065635A (zh) * 2021-02-27 2021-07-02 华为技术有限公司 一种模型的训练方法、图像增强方法及设备
CN113112020A (zh) * 2021-03-25 2021-07-13 厦门大学 一种基于生成网络与知识蒸馏的模型网络提取和压缩方法
CN113178255A (zh) * 2021-05-18 2021-07-27 西安邮电大学 一种基于gan的医学诊断模型对抗攻击方法
CN113449680A (zh) * 2021-07-15 2021-09-28 北京理工大学 一种基于知识蒸馏的多模小目标检测方法
CN113538334A (zh) * 2021-06-09 2021-10-22 香港中文大学深圳研究院 一种胶囊内窥镜图像病变识别装置及训练方法
CN114529622A (zh) * 2022-01-12 2022-05-24 华南理工大学 通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法及装置
CN115034983A (zh) * 2022-05-30 2022-09-09 国网四川省电力公司眉山供电公司 一种输变电设备图像数据增广方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11048974B2 (en) * 2019-05-06 2021-06-29 Agora Lab, Inc. Effective structure keeping for generative adversarial networks for single image super resolution

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111598216A (zh) * 2020-04-16 2020-08-28 北京百度网讯科技有限公司 学生网络模型的生成方法、装置、设备及存储介质
CN112465111A (zh) * 2020-11-17 2021-03-09 大连理工大学 一种基于知识蒸馏和对抗训练的三维体素图像分割方法
CN113065635A (zh) * 2021-02-27 2021-07-02 华为技术有限公司 一种模型的训练方法、图像增强方法及设备
CN113112020A (zh) * 2021-03-25 2021-07-13 厦门大学 一种基于生成网络与知识蒸馏的模型网络提取和压缩方法
CN113178255A (zh) * 2021-05-18 2021-07-27 西安邮电大学 一种基于gan的医学诊断模型对抗攻击方法
CN113538334A (zh) * 2021-06-09 2021-10-22 香港中文大学深圳研究院 一种胶囊内窥镜图像病变识别装置及训练方法
CN113449680A (zh) * 2021-07-15 2021-09-28 北京理工大学 一种基于知识蒸馏的多模小目标检测方法
CN114529622A (zh) * 2022-01-12 2022-05-24 华南理工大学 通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法及装置
CN115034983A (zh) * 2022-05-30 2022-09-09 国网四川省电力公司眉山供电公司 一种输变电设备图像数据增广方法

Non-Patent Citations (4)

* Cited by examiner, † Cited by third party
Title
Yu Yang et al..Using Generative Adversarial Networks Based on Dual Attention Mechanism to Generate Face Images.《2021 International Conference on Computer Technology and Media Convergence Design(CTMCD)》.2021,全文. *
周立君 等.一种基于GAN和自适应迁移学习的样本生成方法.应用光学.2020,(第01期),全文. *
王星 等.基于深度残差生成式对抗网络的样本生成方法.控制与决策.2020,(第08期),全文. *
黄永松.可见光图像与红外图像的超分辨率重建算法研究.《中国优秀硕士学位论文全文数据库 信息科技辑》.2022,(第undefined期),全文. *

Also Published As

Publication number Publication date
CN115564024A (zh) 2023-01-03

Similar Documents

Publication Publication Date Title
CN109493350B (zh) 人像分割方法及装置
CN110532996B (zh) 视频分类的方法、信息处理的方法以及服务器
US9501724B1 (en) Font recognition and font similarity learning using a deep neural network
CN110648334A (zh) 一种基于注意力机制的多特征循环卷积显著性目标检测方法
EP3963516B1 (en) Teaching gan (generative adversarial networks) to generate per-pixel annotation
EP4318313A1 (en) Data processing method, training method for neural network model, and apparatus
CN111260020A (zh) 卷积神经网络计算的方法和装置
US20230153965A1 (en) Image processing method and related device
CN107886491A (zh) 一种基于像素最近邻的图像合成方法
CN111814534A (zh) 视觉任务的处理方法、装置和电子系统
CN111680619A (zh) 基于卷积神经网络和双注意力机制的行人检测方法
KR20190044761A (ko) 이미지 처리 장치 및 방법
CN116912924B (zh) 一种目标图像识别方法和装置
Ren et al. Exploring simple triplet representation learning
CN115564024B (zh) 生成网络的特征蒸馏方法、装置、电子设备及存储介质
CN117671371A (zh) 一种基于代理注意力的视觉任务处理方法和系统
Malekijoo et al. Convolution-deconvolution architecture with the pyramid pooling module for semantic segmentation
Ye et al. A multi-attribute controllable generative model for histopathology image synthesis
Wang et al. SCNet: Scale-aware coupling-structure network for efficient video object detection
CN110633630A (zh) 一种行为识别方法、装置及终端设备
CN115222838A (zh) 视频生成方法、装置、电子设备及介质
CN114329070A (zh) 视频特征提取方法、装置、计算机设备和存储介质
CN114332989A (zh) 一种多任务级联卷积神经网络的人脸检测方法及系统
CN114627293A (zh) 基于多任务学习的人像抠图方法
CN114820755A (zh) 一种深度图估计方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant