CN116362324A - 一种用于生成对抗网络的蒸馏方法、装置、设备及介质 - Google Patents

一种用于生成对抗网络的蒸馏方法、装置、设备及介质 Download PDF

Info

Publication number
CN116362324A
CN116362324A CN202310304741.8A CN202310304741A CN116362324A CN 116362324 A CN116362324 A CN 116362324A CN 202310304741 A CN202310304741 A CN 202310304741A CN 116362324 A CN116362324 A CN 116362324A
Authority
CN
China
Prior art keywords
generator
teacher
student
discriminator
result graph
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310304741.8A
Other languages
English (en)
Inventor
林彦硕
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Wondershare Software Co Ltd
Original Assignee
Shenzhen Wondershare Software Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Wondershare Software Co Ltd filed Critical Shenzhen Wondershare Software Co Ltd
Priority to CN202310304741.8A priority Critical patent/CN116362324A/zh
Publication of CN116362324A publication Critical patent/CN116362324A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/096Transfer learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0475Generative networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/091Active learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/094Adversarial learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T3/00Geometric image transformations in the plane of the image
    • G06T3/04Context-preserving transformations, e.g. by using an importance map
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例公开一种用于生成对抗网络的蒸馏方法包括:获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,第一导师生成器包括第一预设通道数乘以第一预设层数的数据,第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,第一预设通道数大于第二预设通道数,第一预设层数小于第二预设层数;通过第一导师生成器进行学习,以得到第一结果图;通过使用第二导师生成器,以得到第二结果图;基于鉴别器对第一导师生成器得到的第一结果图和第二导师生成器得到的第二结果图进行损失计算和更新;至少基于鉴别器进行损失计算和更新后的第一导师生成器,根据核对齐知识蒸馏函数计算第一导师生成器与学生生成器的相似度。

Description

一种用于生成对抗网络的蒸馏方法、装置、设备及介质
技术领域
本申请实施例涉及图像或视频软件的图像处理技术领域,尤其涉及一种用于生成对抗网络的蒸馏方法、装置、设备及介质。
背景技术
生成数据、模型时,计算机可以将输入的原始影像换成输入的目标影像,且能有较清晰的输出结果。此外也能将现有模型加速。
对此,常用的手段为生成对抗网络(GAN,GenerativeAdversarial Networks)。GAN在生成优秀图像方面取得了巨大的成功,但是,由于计算成本高,内存使用量大,在资源有限的设备上部署GAN非常困难。尽管最近压缩GAN的努力取得了显著的成果,但它们仍然存在潜在的模型冗余,可以进一步压缩。
面对GAN的相关技术中,要么仅注重算法重建速度而牺牲了重建效果,要么过于注重效果而牺牲了算力,大大降低了用户的使用体验,因此,缺少一种兼顾速度和质量的用于生成对抗网络(GAN)的蒸馏方法。
发明内容
针对上述相关技术中存在的问题,本申请实施例提供了一种用于生成对抗网络的蒸馏方法,可以兼顾生成对抗网络的蒸馏的速度和质量,大大提升用户体验。
第一方面,本申请实施例提供了用于生成对抗网络的蒸馏方法,可包括:获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数;通过所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图;基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新;至少基于所述第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
进一步地,所述至少基于所述鉴别器优化后的第一导师生成器,优化所述学生生成器,优化所述学生生成器,包括:基于所述鉴别器进行损失计算和更新后的第一结果图和第二结果图,对所述学生生成器进行优化。
进一步地,所述基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新,包括:基于第一鉴别器的各层的特征图对所述第一导师生成器得到的第一结果图进行计算,以得到第一结果图损失;以及基于第二鉴别器的各层的特征图对所述第二导师生成器得到的第二结果图进行计算,以得到第二结果图损失。
进一步地,所述基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新,包括:基于第一鉴别器的各层的特征图对所述第一导师生成器得到的第一结果图进行计算,以得到第一结果图损失;以及基于第二鉴别器的各层的特征图对所述第二导师生成器得到的第二结果图进行计算,以得到第二结果图损失。
进一步地,所述第一鉴别器和所述第二鉴别器所处理的图像的尺寸大小为相同或不同;或者所述第一鉴别器和所述第二鉴别器的前N层为共享鉴别层,其中,N≥1。
进一步地,所述方法还包括:所述方法还包括:当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像。
进一步地,所述当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像包括:所述学生生成器至少根据结构相似度损失函数、风格损失函数和/或平滑度损失函数,对所述输入图像进行重建,以得到输出图像。
第二方面,本申请实施例还提供了一种用于生成对抗网络的蒸馏装置,可包括:
获取模块,用于获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数;学习模块,用于通过使用所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图;鉴别模块,用于基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新;以及优化模块,用于至少基于所述第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
第三方面,本申请实施例还提供了一种计算机设备,其中,包括:存储器以及处理器,所述存储器用于存储并支持处理器执行第一方面中任一项所述方法的程序,所述处理器被配置为用于执行所述存储器中存储的程序。
第四方面,本申请实施例还提供了一种具有处理器可执行的非易失的程序代码的计算机可读介质,其中,所述程序代码使所述处理器执行所述第一方面的任一所述方法。
本申请实施例中,由于第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数的设置,能够使得第一导师生成器和第二导师生成器在宽度和深度信息上实现互补,有利于学生生成器仅使用第一导师生成器和第二导师生成器进行优化时的蒸馏效果,同时学生生成器可以在无鉴别器的情况下进行损失优化,以解决非合作博弈均衡问题。同时,由于采用核对齐知识蒸馏函数计算所述第一导师生成器与所述学生生成器的相似度,能够快速高效的得到最优效果。
附图说明
为了更清楚地说明本申请实施例或相关技术中的技术方案,下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图示出的结构获得其他的附图。
图1为本申请实施例提供的一种用于生成对抗网络的蒸馏方法的一个流程示意图;
图2为本申请实施例提供的另一种用于生成对抗网络的蒸馏方法的一个流程示意图;
图3为本申请实施例提供的又一种用于生成对抗网络的蒸馏方法的一个流程示意图;
图4为本申请实施例提供的用于生成对抗网络的蒸馏装置的示意性框图;
图5为本申请实施例提供的计算机设备的示意性框图。
本申请目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在此本申请说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本申请。如在本申请说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
还应当更进一步理解,本申请的说明书和权利要求书及上述附图中,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本申请的描述中,除非另有说明,“多个”的含义是两个或两个以上。对于本领域的普通技术人员而言,可以根据具体情况理解上述术语在本申请中的具体含义。
生成对抗网络(GAN,GenerativeAdversarialNetworks),在生成优秀图像方面取得了巨大的成功,但是,由于计算成本高,内存使用量大,在资源有限的设备上部署GAN非常困难。尽管最近压缩GAN的努力取得了显著的成果,但它们仍然存在潜在的模型冗余,可以进一步压缩。
面对GAN的相关技术中,要么仅注重算法重建速度而牺牲了重建效果,要么过于注重效果而牺牲了算力,大大降低了用户的使用体验,因此,缺少一种兼顾速度和质量的用于生成对抗网络(GAN)的蒸馏方法。
具体的,相关技术中,设计一个包含各种宽度和深度的模块,将此模块替换到原模型的卷基层。训练替换模块后的模型,再从训练完的模型中的模块挑选适合的宽度和深度组成最后的模型架构。之后设定一个时间参数,从替换后的模型中寻找适合的模型架构。寻找后的模型架构会符合该时间参数。然而这并不能保证这个架构能得到好的结果。也就是说此方法在模型架构上只考虑了速度,并没有考虑效果。
有鉴于此,本申请实施例提供了一种用于生成对抗网络的蒸馏方法、装置、设备及介质。
该用于生成对抗网络的蒸馏方法的执行主体可以是本申请实施例提供的用于生成对抗网络的蒸馏装置,或者集成了该用于生成对抗网络的蒸馏装置的计算机设备,其中,该用于生成对抗网络的蒸馏装置可以采用硬件或者软件的方式实现,该计算机设备可以为终端或服务器,该终端可以是智能手机、平板电脑、掌上电脑、或者笔记本电脑等。
首先,对本申请可能涉及的相关概念进行简要介绍。
蒸馏:类比人类的学习过程,在知识蒸馏中称要进行压缩的模型为导师生成器(TeacherModel),压缩之后的模型为学生生成器(StudentModel),一般情况下,导师生成器的体积要远大于学生生成器。一般的知识蒸馏过程为首先利用数据集训练导师生成器,让导师生成器充分学习数据中包含的知识,然后在利用数据集训练学生生成器时,通过蒸馏方法将导师生成器中已经学习到的知识提取出来,指导学生生成器的训练,这样学生生成器相当于从导师生成器那里获取到了关于数据集的先验信息。也就是在传统的知识蒸馏中,导师生成器是预先在数据集上进行过训练的,然后在学生生成器的训练过程中利用自身学习到的知识对其进行指导,帮助提高学生生成器的准确率。
1Lloss:又称为L1范数损失、最小绝对偏差(LAD)或平均绝对值误差(MAE),无论对于什么样的输入值,都有着稳定的梯度,不会导致梯度爆炸问题,具有较为稳健性的解。但是,在中心点是折点,不能求导,梯度下降时要是恰好学习到w=0就无法继续进行。
L2loss:也称为均方误差MSE(L2LOSS),均方误差(MeanSquare Error,MSE)是模型预测值f(x)和样本真实值y之间差值平方的平均值。其各点都连续光滑,方便求导,具有较为稳定的解。但是,不是特别的稳健,因为当函数的输入值距离真实值较远的时候,对应loss值很大在两侧,则使用梯度下降法求解的时候梯度很大,可能导致梯度爆炸。
SmoothL1Loss:平滑版本的L1LOSS,当预测值f(xi)和真实值yi差别较小的时候(绝对值差小于1),其实使用的是L2loss;差别大的时候,使用的是L1loss的平移。因此,SmoothL1loss其实是L1loss和L2loss的结合,同时拥有两者的部分优点:真实值和预测值差别较小时(绝对值差小于1),梯度也会比较小(损失函数比普通L1loss在此处更圆滑),可以收敛得更快。真实值和预测值差别较大时,梯度值足够小(普通L2loss在这种位置梯度值就很大,容易梯度爆炸)。
上述仅是对本申请实施例的技术原理和示例性的应用框架的说明,下面通过多个实施例来进一步对本申请实施例具体技术方案进行详细描述。请参阅图1和图2,本申请实施例中一种用于生成对抗网络的蒸馏方法的一个实施例,可包括:
S100:获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数。
在一些实施方式中,可以基于损失函数和训练设置来训练生成第一导师生成器、第二导师生成器和鉴别器。其中,第一导师生成器和第二导师生成器均旨在学习一个函数,将数据从源域X映射到目标域Y。学生生成器仅利用第一导师生成器、第二导师生成器进行优化,因此可以在无鉴别器的环境中进行训练。学生生成器的优化不需要同时使用真实标签y。也就是说,学生生成器只学习具有类似结构(第一导师生成器、第二导师生成器)的大容量生成器的输出,这大大降低了直接拟合y的难度。具体来说,我们在每个迭代步骤中反向传播第一导师生成器、第二导师生成器的蒸馏损失。通过这种方式,学生生成器可以模仿第一导师生成器、第二导师生成器的训练过程,逐步学习。
在一些实施方式中,第一导师生成器和第二导师生成器具备互补结构和来自不同层次的知识。由此,有助于从真实标签中捕捉更多互补的图像线索,并从不同角度提高图像重建性能。具体的,第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数,也即,第一导师生成器具有更广泛的信息,第二导师生成器具有深度更深的信息。
在一些更具体的实施例中,假设给定一个学生生成器,可以根据该初始的学生生成器扩展了学生生成器的通道以获得更宽的第一导师生成器。具体而言,学生生成器的卷积层(也即上述预设层)的每个通道乘以通道扩展因子η。可以根据学生生成器得到第二导师生成器,具体的,在每个下采样和上采样层之后插入几个残差块到学生生成器中,以构建更深层的第二导师生成器,其容量与第一导师生成器相当。
S200:通过所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图。
具体的,所述第一导师生成器通过使用第一感知损失函数进行重建,以得到所述第一结果图;所述第二导师生成器通过使用第二感知损失函数进行重建,以得到所述第二结果图。
在一些实施方式中,第一感知损失函数和/或第二感知损失函数可以采用L1loss对图像进行重建以得到第一结果图和第二结果图。L1loss是针对生成结果和真实结果(groundtruth)的像素做一对一对的算差值和平均。此实施方式基本上能重建颜色,但是对内容的重建效果不佳。更进一步的,在一些实施方式中,第一感知损失函数和/或第二感知损失函数可以增加包括以下的重建损失函数。针对真实物体(例如杯子等等)和背景,第一感知损失函数和/或第二感知损失函数可以包括图像相似度度量标准损失函数(lpipsloss),因为此损失函数使用针对真实物体进行分类的预训练模型进行特征比对,所以对真实物体和背景能有较好的重建能力。
S300:基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新。
在更进一步的实施方式中,为了能针对各种卡通人脸进行运算,本实施方式中,将鉴别器的各层的特征图撷取出来,对生成的特征图和真实特征图进行比对和统计。具体的,包括:基于第一鉴别器的各层的特征图对所述第一导师生成器得到的第一结果图进行计算,以得到第一结果图损失;以及基于第二鉴别器的各层的特征图对所述第二导师生成器得到的第二结果图进行计算,以得到第二结果图损失。由于从鉴别器提取出的特征,所以是针对输入图像(例如卡通人脸)的重建效果会比前述lpipsloss的效果更优。
更具体的,所述第一鉴别器和所述第二鉴别器所处理的图像的尺寸大小为相同或不同。当所述第一鉴别器和所述第二鉴别器所处理的图像的尺寸大小不同时,可以让生成器学到不同尺寸的输入图的转换。
在一些实施方式中,所述第一鉴别器和所述第二鉴别器的前N层为共享鉴别层,其中,N≥1。具体的,该鉴别器为部分共享鉴别器,其被设计为共享前几层,并分离两个分支,以分别获得第一导师生成器和第一导师生成器的鉴别器输出,从而得到第一结果图损失和第二结果图损失。这种共享设计不仅提供了鉴别器的高度灵活性,而且还利用了输入图像的相似特性来改进生成器的训练。
在一些实施方式中,鉴别器会将该第一结果图损失反向传输回第一导师生成器,第二结果图损失反向传输至第二导师生成器,以此来逐步优化迭代第一导师生成器和第二导师生成器。
S400:至少基于所述鉴别器进行损失计算和更新后的第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
具体的,所述至少基于所述鉴别器优化后的第一导师生成器,优化所述学生生成器,优化所述学生生成器,包括:基于所述鉴别器进行损失计算和更新后的第一结果图和第二结果图,对所述学生生成器进行优化。
在一些实施方式中,可以透过1*1卷机扩充学生生成器的通道数,使得学生生成器的通道数和第一导师生成器的通道数相等,以此求得第一导师生成器和学生生成器的相似度,从而评估当前的学生生成器是否可以作为最终的生成器进行蒸馏学习。此实施方式能够实现一定的重建效果,但是因为使用1*1卷基层改变通道数后学习,这导致有资讯累积在1*1卷基层内,不是真正将两个模型的特征图进行比较,所以效果会比较差。
在另一些实施方式中,采用核对齐知识蒸馏函数计算所述第一导师生成器与所述学生生成器的相似度。具体而言,核对齐知识蒸馏函数为:
KA(X,Y)=(||YTX||F 2)/(||XTX||F||YTY||F)(1)
其中,Y是学生生成器的特征图,X是第一导师生成器的特征图。
由于使用核对齐知识蒸馏函数直接比对第一导师生成器和学生生成器的特征图相似度,如此解决了1*1卷基层的资料残留问题。
在一些实施方式中,本方法还可以包括:
S500:当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像。
具体的,所述当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像,包括:所述学生生成器至少根据结构相似度损失函数、风格损失函数和/或平滑度损失函数,对所述输入图像进行重建,以得到输出图像。在一些实施方式中,输入图像可以为输入的原始影像(客户照片),输出图像可以为计划换成的输入的目标影像(任意人脸照片),且能有较清晰的输出结果。
在一些实施方式中,结构相似度损失函数SSIM(StructuralSimilarityloss)主要处理亮度、对比度和结构。风格损失(styleloss)计算生成图的风格和真实图风格是否相似。最后平滑度损失函数(totalvariationloss,TVloss)处理图的平滑程度。更具体的,还包括使用lpips真实背景和物体的重建。在一些实施方式中,为处理卡通图案,可以增加比对鉴别器的特征图差异作为重建函数之一。在一些实施方式中,还可以增加颜色损失函数(L1loss、smoothL1loss或l2loss)等作为重建函数。
请参阅图2,图2为本申请实施例提供的另一种用于生成对抗网络的蒸馏方法的一个流程示意图。
在本实施例中,可以通过学生生成器增加宽度或者深度得到第一导师生成器和第二导师生成器。
在一些试试例中,第一导师生成器和第二导师生成器分别通过第一感知损失函数和第二感知损失函数重建得到第一结果图和第二结果图。
鉴别器能够对第一导师生成器和第二导师生成器得到的第一结果图和第二结果图进行损失计算和更新。其中,损失计算主要是计算第一结果图与输入图像的损失函数,和计算第二结果图与输入图像的损失函数,并分别将第一损失结果和第二损失结果反向传输至第一导师生成器和第二导师生成器以持续优化第一导师生成器和第二导师生成器。也就是说,本实施方式中,第一导师生成器和第二导师生成器不是预先设定固定的,而是会不断根据第一结果图损失和第二结果图损失在线更新的,由此,能够确保第一导师生成器和第二导师生成器的学习效果。
学生生成器可以基于挡墙的学生生成器结果和第一结果图以及第二结果图计算第一蒸馏损失和第二蒸馏损失,并根据第一蒸馏损失和第二蒸馏损失的反向传输不断优化学生生成器。
进一步的,学生生成器还可以与第一导师生成器结合第一蒸馏函数进行相似度计算,具体的,为二者的特征图的相似度计算。当该相似度满足预设阈值的时候,表明该学生生成器的蒸馏效果较佳,能够用于对输入图像的重建,以得到输出效果,此时,学生生成器未透过鉴定器,由此避免鉴别器过强,导致学生生成器学不到东西。
请参阅图3,图3为本申请实施例提供的又一种用于生成对抗网络的蒸馏方法的一个流程示意图。本实施方式和图2所示的实施方式类似,区别在于:图3中具体给出了第一导师生成器和第二导师生成器的预设层数和预设通道数。从图中可以看出,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数。
在一些实施方式中,可以利用第一蒸馏函数对第一导师生成器和学生生成器在中间层进行蒸馏学习,由此,学生生成器和导师生成器都是逐步同时优化的,也可以进一步提升优化效率和效果。
在一些实施方式中,KD损失表示蒸馏损失,与前述第一蒸馏损失或第二蒸馏损失一致。GAN损失表示生成算法损失,其与前述第一结果图损失或第二结果图损失一致。
综上,本申请实施例中,由于第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数的设置,能够使得第一导师生成器和第二导师生成器在宽度和深度信息上实现互补,有利于学生生成器仅使用第一导师生成器和第二导师生成器进行优化时的蒸馏效果,同时学生生成器可以在无鉴别器的情况下进行损失优化,以解决非合作博弈均衡问题。同时,由于采用核对齐知识蒸馏函数计算所述第一导师生成器与所述学生生成器的相似度,能够快速高效的得到最优效果。
图4是本申请实施例提供的一种用于生成对抗网络的蒸馏装置的示意性框图。如图4所示,对应于以上用于生成对抗网络的蒸馏方法,本申请还提供一种用于生成对抗网络的蒸馏装置100。该用于生成对抗网络的蒸馏装置100包括用于执行上述用于生成对抗网络的蒸馏方法的单元,该装置可以被配置于台式电脑、平板电脑、手提电脑、等终端中。具体地,请参阅图4,该用于生成对抗网络的蒸馏装置100包括获取模块101、学习模块102、鉴别模块103以及优化模块104,其中:
获取模块101,用于获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数;
学习模块102,用于通过使用所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图;
鉴别模块103,用于基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新;
以及
优化模块104,用于至少基于所述第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
在一些实施例中,学习模块102在实现所述通过使用所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图时,具体包括:所述第一导师生成器通过使用第一感知损失函数进行重建,以得到所述第一结果图;所述第二导师生成器通过使用第二感知损失函数进行重建,以得到所述第二结果图。
在一些实施例中,鉴别模块103在执行在所述基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新时,具体用于:基于第一鉴别器的各层的特征图对所述第一导师生成器得到的第一结果图进行计算,以得到第一结果图损失;以及基于第二鉴别器的各层的特征图对所述第二导师生成器得到的第二结果图进行计算,以得到第二结果图损失。
在一些实施例中,鉴别模块103中,所述第一鉴别器和所述第二鉴别器所处理的图像的尺寸大小为相同或不同;或者所述第一鉴别器和所述第二鉴别器的前N层为共享鉴别层,其中,N≥1。
在一些实施例中,优化模块104在实现所述至少基于所述鉴别器优化后的第一导师生成器,优化所述学生生成器,优化所述学生生成器时,具体包括:基于所述鉴别器进行损失计算和更新后的第一结果图和第二结果图,对所述学生生成器进行优化。
在一些实施例中,所述生成对抗网络的蒸馏装置100还包括重建模块105,用于当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像。
在一些实施方式中,重建模块105在实现所述当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像时,具体包括:所述学生生成器至少根据结构相似度损失函数、风格损失函数和/或平滑度损失函数,对所述输入图像进行重建,以得到输出图像。
需要说明的是,所属领域的技术人员可以清楚地了解到,上述用于生成对抗网络的蒸馏装置100和各单元的具体实现过程,可以参考前述方法实施例中的相应描述,为了描述的方便和简洁,在此不再赘述。
上述用于生成对抗网络的蒸馏装置100可以实现为一种计算机程序的形式,该计算机程序可以在如图5所示的计算机设备上运行。
请参阅图5,图5是本申请实施例提供的一种计算机设备的示意性框图。该计算机设备200可以是终端,也可以是服务器,其中,终端可以是智能手机、平板电脑、笔记本电脑、台式电脑、个人数字助理和穿戴式设备等具有通信功能的电子设备。服务器可以是独立的服务器,也可以是多个服务器组成的服务器集群。
参阅图5,该计算机设备200包括通过系统总线201连接的处理器202、存储器和网络接口205,其中,存储器可以包括非易失的程序代码的计算机可读介质203和内存储器204。
该非易失的程序代码的计算机可读介质203可存储操作系统2031和计算机程序2032。该计算机程序2032包括程序指令,该程序指令被执行时,可使得处理器202执行一种基于时频结合的用于生成对抗网络的蒸馏方法。
该处理器202用于提供计算和控制能力,以支撑整个计算机设备200的运行。
该内存储器204为非易失的程序代码的计算机可读介质203中的计算机程序2032的运行提供环境,该计算机程序2032被处理器202执行时,可使得处理器202执行一种用于生成对抗网络的蒸馏方法。
该网络接口205用于与其它设备进行网络通信。本领域技术人员可以理解,图5中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备200的限定,具体的计算机设备200可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
其中,所述处理器202用于运行存储在存储器中的计算机程序2032,以实现如下步骤:
获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数;
通过所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图;
基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新;
至少基于所述鉴别器进行损失计算和更新后的第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
在一些实施例中,所述处理器202用于在实现所述通过使用所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图时,具体包括:所述第一导师生成器通过使用第一感知损失函数进行重建,以得到所述第一结果图;所述第二导师生成器通过使用第二感知损失函数进行重建,以得到所述第二结果图。
在一些实施例中,所述处理器202在执行在所述基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新时,具体用于:基于第一鉴别器的各层的特征图对所述第一导师生成器得到的第一结果图进行计算,以得到第一结果图损失;以及基于第二鉴别器的各层的特征图对所述第二导师生成器得到的第二结果图进行计算,以得到第二结果图损失。
在一些实施例中,所述处理器202用于,所述第一鉴别器和所述第二鉴别器所处理的图像的尺寸大小为相同或不同;或者所述第一鉴别器和所述第二鉴别器的前N层为共享鉴别层,其中,N≥1。
在一些实施例中,所述处理器202在实现所述至少基于所述鉴别器优化后的第一导师生成器,优化所述学生生成器,优化所述学生生成器时,具体包括:基于所述鉴别器进行损失计算和更新后的第一结果图和第二结果图,对所述学生生成器进行优化。
在一些实施例中,所述处理器202用于当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像。
在一些实施方式中,所述处理器202用于在实现所述当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像时,具体包括:所述学生生成器至少根据结构相似度损失函数、风格损失函数和/或平滑度损失函数,对所述输入图像进行重建,以得到输出图像。
应当理解,在本申请实施例中,处理器202可以是中央处理单元(CentralProcessingUnit,CPU),该处理器202还可以是其他通用处理器、数字信号处理器(DigitalSignalProcessor,DSP)、专用集成电路(ApplicationSpecificIntegratedCircuit,ASIC)、现成可编程门阵列(Field-ProgrammableGateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。其中,通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
所述存储介质可以是U盘、移动硬盘、只读存储器(Read-OnlyMemory,ROM)、磁碟或者光盘等各种可以存储程序代码的计算机可读存储介质。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的。例如,各个单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式。例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
本申请实施例方法中的步骤可以根据实际需要进行顺序调整、合并和删减。本申请实施例装置中的单元可以根据实际需要进行合并、划分和删减。另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。
该集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个存储介质中。基于这样的理解,本申请的技术方案本质上或者说对相关技术做出贡献的部分,或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,终端,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围为准。

Claims (10)

1.一种用于生成对抗网络的蒸馏方法,其特征在于,包括:
获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数;
通过所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图;
基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新;
至少基于所述鉴别器进行损失计算和更新后的第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
2.根据权利要求1所述的方法,其特征在于,所述至少基于所述鉴别器优化后的第一导师生成器,优化所述学生生成器,优化所述学生生成器,包括:基于所述鉴别器进行损失计算和更新后的第一结果图和第二结果图,对所述学生生成器进行优化。
3.根据权利要求1所述的方法,其特征在于,所述通过使用所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图,至少包括:
所述第一导师生成器通过使用第一感知损失函数进行重建,以得到所述第一结果图;
所述第二导师生成器通过使用第二感知损失函数进行重建,以得到所述第二结果图。
4.根据权利要求1所述的方法,其特征在于,所述基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新,包括:
基于第一鉴别器的各层的特征图对所述第一导师生成器得到的第一结果图进行计算,以得到第一结果图损失;以及
基于第二鉴别器的各层的特征图对所述第二导师生成器得到的第二结果图进行计算,以得到第二结果图损失。
5.根据权利要求1所述的方法,其特征在于,所述第一鉴别器和所述第二鉴别器所处理的图像的尺寸大小为相同或不同;或者
所述第一鉴别器和所述第二鉴别器的前N层为共享鉴别层,其中,N≥1。
6.根据权利要求1所述的方法,其特征在于,所述方法还包括:当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像。
7.根据权利要求6所述的方法,其特征在于,所述当所述第一导师生成器和所述学生生成器的相似度大于预设阈值时,根据所述学生生成器对输入图像进行重建,以得到输出图像,包括:
所述学生生成器至少根据结构相似度损失函数、风格损失函数和/或平滑度损失函数,对所述输入图像进行重建,以得到输出图像。
8.一种用于生成对抗网络的蒸馏装置,其特征在于,包括:
获取模块,用于获取第一导师生成器、第二导师生成器、鉴别器和学生生成器,其中,所述第一导师生成器包括第一预设通道数乘以第一预设层数的数据,所述第二导师生成器包括第二预设通道数乘以第二预设层数的数据,其中,所述第一预设通道数大于所述第二预设通道数,所述第一预设层数小于所述第二预设层数;
学习模块,用于通过使用所述第一导师生成器进行学习,以得到第一结果图;通过使用所述第二导师生成器,以得到第二结果图;
鉴别模块,用于基于所述鉴别器对所述第一导师生成器得到的第一结果图和所述第二导师生成器得到的第二结果图进行损失计算和更新;
以及
优化模块,用于至少基于所述第一导师生成器,优化所述学生生成器;所述优化包括:根据第一蒸馏函数,计算所述第一导师生成器与所述学生生成器的相似度,其中,所述第一蒸馏函数为核对齐知识蒸馏函数。
9.一种计算机设备,其特征在于,包括:存储器以及处理器,所述存储器用于存储并支持处理器执行权利要求1至7中任一项所述方法的程序,所述处理器被配置为用于执行所述存储器中存储的程序。
10.一种具有处理器可执行的非易失的程序代码的计算机可读介质,其特征在于,所述程序代码使所述处理器执行权利要求1至7任一所述方法。
CN202310304741.8A 2023-03-14 2023-03-14 一种用于生成对抗网络的蒸馏方法、装置、设备及介质 Pending CN116362324A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310304741.8A CN116362324A (zh) 2023-03-14 2023-03-14 一种用于生成对抗网络的蒸馏方法、装置、设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310304741.8A CN116362324A (zh) 2023-03-14 2023-03-14 一种用于生成对抗网络的蒸馏方法、装置、设备及介质

Publications (1)

Publication Number Publication Date
CN116362324A true CN116362324A (zh) 2023-06-30

Family

ID=86906459

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310304741.8A Pending CN116362324A (zh) 2023-03-14 2023-03-14 一种用于生成对抗网络的蒸馏方法、装置、设备及介质

Country Status (1)

Country Link
CN (1) CN116362324A (zh)

Similar Documents

Publication Publication Date Title
US10991150B2 (en) View generation from a single image using fully convolutional neural networks
CN111784821B (zh) 三维模型生成方法、装置、计算机设备及存储介质
CN111179177A (zh) 图像重建模型训练方法、图像重建方法、设备及介质
WO2023103576A1 (zh) 视频处理方法、装置、计算机设备及存储介质
JP2022522564A (ja) 画像処理方法及びその装置、コンピュータ機器並びにコンピュータプログラム
CN113066034A (zh) 人脸图像的修复方法与装置、修复模型、介质和设备
CN115345866B (zh) 一种遥感影像中建筑物提取方法、电子设备及存储介质
CN113344869A (zh) 一种基于候选视差的行车环境实时立体匹配方法及装置
CN114782864B (zh) 一种信息处理方法、装置、计算机设备及存储介质
CN115131218A (zh) 图像处理方法、装置、计算机可读介质及电子设备
CN110874575A (zh) 一种脸部图像处理方法及相关设备
CN116485741A (zh) 一种无参考图像质量评价方法、系统、电子设备及存储介质
CN114783022B (zh) 一种信息处理方法、装置、计算机设备及存储介质
CN114821404A (zh) 一种信息处理方法、装置、计算机设备及存储介质
CN115496925A (zh) 图像处理方法、设备、存储介质及程序产品
CN116797768A (zh) 全景图像减少现实的方法和装置
KR101795952B1 (ko) 2d 영상에 대한 깊이 영상 생성 방법 및 장치
Liu et al. Facial image inpainting using multi-level generative network
CN113538254A (zh) 图像恢复方法、装置、电子设备及计算机可读存储介质
US20230073175A1 (en) Method and system for processing image based on weighted multiple kernels
CN116362324A (zh) 一种用于生成对抗网络的蒸馏方法、装置、设备及介质
CN114898244A (zh) 一种信息处理方法、装置、计算机设备及存储介质
CN115311152A (zh) 图像处理方法、装置、电子设备以及存储介质
CN115035170A (zh) 基于全局纹理与结构的图像修复方法
CN113822790A (zh) 一种图像处理方法、装置、设备及计算机可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination