CN113160041A - 一种模型训练方法及模型训练装置 - Google Patents

一种模型训练方法及模型训练装置 Download PDF

Info

Publication number
CN113160041A
CN113160041A CN202110495293.5A CN202110495293A CN113160041A CN 113160041 A CN113160041 A CN 113160041A CN 202110495293 A CN202110495293 A CN 202110495293A CN 113160041 A CN113160041 A CN 113160041A
Authority
CN
China
Prior art keywords
model
image
loss
target frame
frame image
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110495293.5A
Other languages
English (en)
Other versions
CN113160041B (zh
Inventor
王鑫宇
刘炫鹏
杨国基
刘致远
刘云峰
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Zhuiyi Technology Co Ltd
Original Assignee
Shenzhen Zhuiyi Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Zhuiyi Technology Co Ltd filed Critical Shenzhen Zhuiyi Technology Co Ltd
Priority to CN202110495293.5A priority Critical patent/CN113160041B/zh
Publication of CN113160041A publication Critical patent/CN113160041A/zh
Application granted granted Critical
Publication of CN113160041B publication Critical patent/CN113160041B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T3/00Geometric image transformations in the plane of the image
    • G06T3/04Context-preserving transformations, e.g. by using an importance map
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明实施例公开了一种模型训练方法及训练装置,用于在图像翻译模型的训练数据较少时,提升图像翻译模型的图像翻译质量。本发明实施例方法包括:利用训练数集对图像翻译模型的生成器和判别器做训练,并将训练后的图像翻译模型视为老师模型,训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和目标帧图像的前N帧图像数据;利用训练数集中的第一数据对图像翻译模型的生成器和判别器进行训练,并将训练后的图像翻译模型视为学生模型,第一数据包括目标帧图像、目标帧图像的轮廓线数据和目标帧图像的前M帧图像数据,M为大于等于1且小于等于N的整数;利用老师模型对学生模型进行知识蒸馏,得到知识蒸馏后的学生模型。

Description

一种模型训练方法及模型训练装置
技术领域
本发明涉及图像翻译技术领域,尤其涉及一种模型训练方法及模型训练装置。
背景技术
所谓图像翻译,指从一副图像到另一副图像的转换。可以类比机器翻译,将一种语言转换为另一种语言。
现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid。pix2pix提出了一个统一的框架解决了各类图像翻译问题,pix2pixHD则在pix2pix的基础上,较好地解决了高分辨率图像转换(翻译)的问题,vid2vid则在pix2pixHD的基础上,较好地解决了高分辨率的视频转换问题。
数字人,是一种利用信息科学的方法对真实人体的形态和功能进行姿态仿真的虚拟人。目前的图像翻译模型,可以对图像中的数字人进行虚拟仿真,但现有技术中的图像翻译模型,如果训练数据的类型较少,会导致训练得到的图像翻译模型在数字人姿态仿真(或数字人姿态生成)时准确率较低。
发明内容
本发明实施例提供了一种模型训练方法及模型训练装置,用于在图像翻译模型的训练数据较少时,也可以提升图像翻译模型的图像翻译质量,从而使得图像翻译模型在实现数字人姿态仿真时,提升数字人姿态仿真的准确率。
本申请实施例第一方面提供了一种模型训练方法,包括:
利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
优选的,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。
优选的,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。
优选的,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述利用所述老师模型对所述学生模型进行知识蒸馏,包括:
将所述老师模型中的判别器作为所述学生模型中的判别器;
根据所述学生模型的损失函数,计算所述学生模型的第一损失;
计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
优选的,所述根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新,包括:
获取所述第一损失、所述第二损失和所述第三损失对应的权重;
根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
本申请实施例第二方面提供了一种模型训练装置,包括:
第一训练单元,用于利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
第二训练单元,用于利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
知识蒸馏单元,用于利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
优选的,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。
优选的,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。
优选的,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述知识蒸馏单元,包括:
设置模块,用于将所述老师模型中的判别器作为所述学生模型中的判别器;
第一计算模块,用于根据所述学生模型的损失函数,计算所述学生模型的第一损失;
第二计算模块,用于计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
第三计算模块,用于计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
更新模块,用于根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
具体的,所述更新模块具体包括:
获取子模块,用于获取所述第一损失、所述第二损失和所述第三损失对应的权重;
计算子模块,用于根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
更新子模块,用于根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
本申请实施例第三方面提供了一种计算机装置,包括处理器,所述处理器在执行存储于存储器上的计算机程序时,用于实现本申请实施例第一方面所述的模型训练方法。
本申请实施例第四方面提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时,用于实现本申请实施例第一方面所述的模型训练方法。
从以上技术方案可以看出,本发明实施例具有以下优点:
本申请实施例中,利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
本申请实施例中利用老师模型对所述学生模型进行知识蒸馏,使得学生模型在输入数据类型减少时,也可以达到接近老师模型的高准确率的图像翻译质量。
附图说明
图1为本申请实施例中模型训练方法的一个实施例示意图;
图2为本申请实施例中神经网络结构的示意图;
图3为本申请实施例中模型训练方法的另一个实施例示意图
图4为本申请实施例中编码模型-解码模型的编码-解码过程示意图;
图5为本申请实施例中模型训练方法的另一个实施例示意图;
图6为本申请实施例中模型训练装置的一个实施例示意图。
具体实施方式
本发明实施例提供了一种模型训练方法及训练装置,用于在图像翻译模型的输入数据较少时,也能达到高准确率的图像翻译质量。
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的实施例能够以除了在这里图示或描述的内容以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
一般的,在对图像翻译模型进行训练时,一般都是采用训练数集对该模型进行训练,其中,该图像翻译模型为生成对抗(GAN)网络模型,而训练数集中的数据一般包括:目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像(distancemap)数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像(distancemap)数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像。
为了提升对图像翻译模型的训练速度,可以减少训练数集中的训练数据,以用于提升对图像翻译模型的训练速度,但若减少训练数集中的训练数据,往往会存在训练完成后的图像翻译模型在输出推理图像时,所输出的图像质量差的问题。
针对该问题,本申请实施例提出了一种模型训练方法及训练装置,用于在减少训练数集中数据的情况下,提升图像翻译模型输出图像的质量。
为方便理解,下面对本申请实施例中的模型训练方法进行描述:请参阅图1,本申请实施例中一种模型训练方法的一个实施例,包括:
101、利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
现有的图像翻译模型都采用的是生成对抗(GAN)网络模型,其中,GAN网络应用到深度学习神经网络上来说,就是通过生成器G(Generator)和判别器D(Discriminator)不断博弈,进而使G学习到数据的分布,如果用到图片生成上,则训练完成后,G可以从一段随机数中生成逼真的图像。
本申请实施例中,先利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像(distancemap)数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像(distancemap)数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像。
下面对训练数集中的数据进行说明:
假设训练数集中任一帧图像的大小为X*X*3,其中,X为图片的像素大小,而3为图片通道的个数,其中3通道代表图片为RGB图像,则训练数集中目标帧图像为3维数据,而训练数集中任一帧图像的距离图像(distancemap)数据为4维数据,则目标帧图像的距离图像为4维数据,而练集中任一帧图像的轮廓线数据为1维数据,则目标帧图像的轮廓线数据为1维数据;对应的,目标帧图像的前N帧图像数据一起为N*8维。
需要说明的是,训练数集中的目标帧图像不参与图像翻译模型的图像推理过程,只用于计算对应的损失,所以训练数集中参与训练的数据一起为8N+5维。
具体的,对图像翻译模型中的生成器和判别器进行训练的过程如下所述:
将训练数集中的数据(即8N+5维数据)输入至图像翻译模型的生成器和判别器中,以得到目标帧的生成图像,然后根据预设的损失函数计算目标帧图像与目标帧的生成图像之间的第一损失,根据第一损失和反向传播算法,对所述生成器中的卷积层的权重进行梯度更新。
需要说明的是,本申请实施例中的目标帧图像与目标帧的生成图像是两个不同的对象,其中,目标帧图像为目标帧的真实图像,即视频数据中目标帧的真实图像,而目标帧的生成图像,是将训练数集中的数据输入至图像翻译模型的生成器中,以生成的图像。
具体的,本申请实施例中的图像翻译模型包括pix2pix,pix2pixHD和vid2vid中的至少一种,对于具体的图像翻译模型,其对应的损失函数也有所不同:
对于pix2pix而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss,及使得输出多样化的GANLoss;
对于pix2pixHD而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、及Content Loss;
对于vid2vid而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、Content Loss、视频Loss和光流Loss;
对于每一种Loss函数在现有技术中都有详细描述,此处不再赘述。
为方便理解梯度更新的过程,先对GAN网络中的生成器和判别器做简单描述:
图像翻译模型中的生成器和判别器采用的是神经网络算法,而多层感知器(Multi-Layer Perceptron,MLP)也叫人工神经网络(Artificial Neural Network,ANN),一般包括输入层、输出层,及设于输入层和输出层之间的多个隐层。最简单的MLP需要有一层隐层,即输入层、隐层和输出层才能称为一个简单的神经网络。
接下来以图2中的神经网络为例,对数据的传导过程进行描述:
1、神经网络的前向输出
其中,第0层(输入层),我们将x1、x2、x3向量化为X;
0层和1层(隐层)之间,存在权重w1、w2、w3,将权重向量化为W[1],其中W[1]表示第一层的权重;
0层和1层(隐层)之间,还存在偏置b1、b2、b3,将其向量化为b[1],其中b[1]表示第一层的权重;
对于第1层,计算公式为:
Z[1]=W[1]X+b[1];
A[1]=sigmoid(Z[1]);
其中,Z为输入值的线型组合,A为Z通过激活函数sigmoid的值,对于第一层的输入值X,输出值为A,也是下一层的输入值,在sigmoid激活函数中,其取值在[0,1]之间,可以将其理解为一个阀,就像人的神经元一样,当一个神经元受到刺激,并不是立刻感觉到,而是这个刺激超声了阀值,才会让神经元向上级传播。
1层和2层(输出层)之间,与0层和1层之间类似,其计算公式如下:
Z[2]=W[2]X+b[2]
A[2]=sigmoid(Z[2])
yhat=A[2];
其中yhat即为本次神经网络的输出值。
2、损失函数
在神经网络训练的过程中,一般通过损失函数来衡量这个神经网络是否训练到位。
一般情况下,我们选择如下函数作为作为损失函数:
Figure BDA0003054003460000091
其中,y为图片的真实特征值,
Figure BDA0003054003460000092
为生成图片的特征值;
当y=1时,若yhat越接近1,
Figure BDA0003054003460000093
越接近0,表示预测效果越好,当损失函数达到最小值时,说明生成器所生成的当前帧的生成图像越接近于当前帧的原始图像。
3、反向传播算法
在上述神经网络模型中,可以通过计算损失函数来得到神经网络的训练效果,同时还可以通过反向传播算法,来更新参数,使得神经网络模型更能得到我们想要的预测值。其中,梯度下降算法即为一种优化权重W和偏置b的方法。
具体的,梯度下降算法是对损失函数求偏导数,然后用偏导数来更新w1、w2和b。
为方便理解,我们将损失函数
Figure BDA0003054003460000101
公式化为以下的公式:
z=w1x1+w2x2+b;
Figure BDA0003054003460000102
Figure BDA0003054003460000103
然后,分别将对α和z求导:
Figure BDA0003054003460000104
Figure BDA0003054003460000105
再对w1、w2和b求导:
Figure BDA0003054003460000106
Figure BDA0003054003460000107
Figure BDA0003054003460000108
接下来用梯度下降算法更新权重参数w和偏置参数b:
其中,w1:=w1-βdw1
w2:=w2-βdw2
b:=b-βdb。
其中,β表示学习率,即学习的步长,在实际训练过程中,如果学习率过大,则会在最优解附近来回震荡,无法达到最优解,如果学习率过小,则可能需要很多次的迭代,才能达到最优解,所以在实际训练过程中,学习率也是一个很重要的选择参数。
而本申请中对生成器和判别器的训练过程,即根据不同图像翻译模型中的损失函数计算对应损失,然后利用反向传播算法,对生成器和判别器中卷积层的权重进行更新的过程,而具体的更新过程可以参阅上述损失函数和反向传播算法的计算过程。
102、利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
为了提升图像翻译模型的训练速度,可以利用训练数集中的第一数据对图像翻译模型进行训练,并将训练完成后的图像翻译模型确定为学生模型,其中,第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数。
下面对训练数集中的第一数据进行说明:
具体的,第一数据包括目标帧图像、目标帧图像的轮廓线数据和目标帧图像的目标帧图像的前M帧图像数据。
相比于训练数集中的数据而言,第一数据少了目标帧图像的distancemap数据,且第一数据仅利用了目标帧图像的前M帧图像数据,其中,M为大于等于1且小于等于N的整数。
因为M为大于等于1且小于等于N的整数,故第一数据相比于训练数集而言,训练数据减少了。
而对于本步骤中前M帧图像数据的具体内容,将在下面的实施例中进行描述,此处不再赘述。
103、利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
因为步骤102中的学生模型在训练过程中所利用的训练数据为训练数集中的第一数据,而第一数据相比于训练数集而言,训练数据减少了,为了提升学生模型在推理图像时,推理出图像的质量。
本申请实施例利用步骤101中的老师模型,对学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
具体的,知识蒸馏是一种知识迁移,意在将复杂模型的泛化能力迁移到更简单的模型中,也即本实施例中将老师模型的泛化能力迁移到学生模型中,使得学生模型虽然在训练过程中只利用了训练数集中的第一数据进行训练,但学生模型却可以拥有老师模型的图像推理能力,也即在图像推理时,可以保持和老师模型一样的图像推理准确率。
具体的,对于具体的知识蒸馏过程,将在下面的实施例中进行描述,此处也不再赘述。
本申请实施例中利用老师模型对所述学生模型进行知识蒸馏,使得学生模型在输入数据类型减少时,也可以达到接近老师模型的高准确率的图像翻译质量,从而在提升了学生模型图像推理的准确率。
基于图1所述的实施例,下面对步骤102中目标帧的前M帧图像数据进行说明:
作为一种可选的实施例,目标帧的前M帧图像数据可以是:所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。
当前M帧图像数据为前M帧图像和前M帧图像的轮廓线数据时,则第一数据为目标帧图像、目标帧图像的轮廓线数据、目标帧图像的前M帧图像和前M帧图像的轮廓线数据,则第一数据总共为(M+1)*4维数据,而因为目标帧图像不参与图像翻译模型的推理训练,而只参加损失函数的计算,故第一数据一起包括4M+1维。
因为学生模型用到的训练数据为4M+1维,相比于训练数集中的8M+5维,总共减少了4M+4维,故学生模型的训练速度较老师模型的训练速度更快。
作为另一种可选的实施例,为了进一步加快学生模型的训练速度,第一数据还可以为降低像素后的前M帧图像,及前M帧图像的轮廓线数据。
当第一数据为降低像素后的前M帧图像和前M帧图像的轮廓线数据时,数据的维度虽然没有发生改变,还为4M+1维,但因为前M帧图像的像素发生了改变,对应的数据的大小也发生了改变,故当第一数据为降低像素后的前M帧图像和前M帧图像的轮廓线数据,较第一数据为前M帧图像和前M帧图像的轮廓线数据而言,学生模型的训练速度更快。
基于图1所述的步骤103,下面对103步骤做详细描述,请参阅图3,本申请实施例中模型训练方法的另一个实施例,包括:
301、将所述老师模型中的判别器作为所述学生模型中的判别器;
在具体的知识蒸馏过程中,将训练好的老师模型的判别器作为学生模型的判别器。
302、根据所述学生模型的损失函数,计算所述学生模型的第一损失;
具体的,本实施例中的图像翻译模型为pix2pix,pix2pixHD和vid2vid中的至少一种,而每一种具体的图像翻译模型,其对应的损失函数也有所不同:
对于pix2pix而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss,及使得输出多样化的GANLoss;
对于pix2pixHD而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、及Content Loss;
对于vid2vid而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、Content Loss、视频Loss和光流Loss;
故学生模型的损失函数,根据图像翻译模型的种类不同也对应不同。
而对应的根据每种模型的损失函数,计算学生模型的第一损失,与现有技术描述的一致,此处也不再赘述。
303、计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
具体的,本实施例中图像翻译模型中的生成器为编码模型-解码模型结构,即encoder-decoder模型,其中,编码模型和解码模型可以采用CNN,RNN,BiRNN、LSTM等深度学习算法中的任一种,此处不做具体限制。
容易理解的是,所谓编码模型的编码过程,就是将输入数据序列转化成一个固定长度的隐藏变量;解码模型,就是将之前生成的固定长度的隐藏变量再转化成输出数据序列,其中,图4给出了编码模型-解码模型的编码-解码过程的示意图。
具体的,本申请实施例是计算老师模型中的第一隐藏变量和学生模型中第二隐藏变量之间的第二损失,然后根据计算出的第二损失,执行步骤305。
具体的,老师模型中的第一隐藏变量是根据训练数集数据进行计算的,而学生模型中的第二隐藏变量是根据训练数集中的第一数据进行计算的,而第一数据相比于训练数集而言,数据明显减少了,故老师模型中的第一隐藏变量与学生模型中的第二隐藏变量之间会存在一个差量,也即本步骤中的第二损失。
304、计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
此外,本申请实施例还计算老师模型的生成器和学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失。
具体的,老师模型的生成器在生成目标帧图像时,是利用目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像(distancemap)数据,及目标帧图像的前N帧图像、前N帧图像的轮廓线数据和前N帧图像的距离图像(distancemap)数据,去得到目标帧的生成图像;而学生模型的生成器在生成目标帧图像时,是利用目标帧图像、目标帧图像的轮廓线数据、目标帧图像的前M帧图像、前M帧图像的轮廓线数据,去得到目标帧的生成图像。
因为老师模型和学生模型在推理目标帧的生成图像时,所利用的数据不同,故老师模型和学生模型的生成器在输入相同的目标帧图像时,所得到的目标帧的生成图像质量(即生成图像的准确率)也不同。
本步骤即为计算老师模型的生成器和学生模型的生成器在输入相同的目标帧图像时,得到的两个目标帧生成图像之间的第三损失,然后执行步骤305。
305、根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
在将老师模型的判别器确定为学生模型的判别器以后,根据步骤302至305步骤计算出的第一损失、第二损失和第三损失中的至少一项及反向传播算法,对学生模型中生成器的卷积层的权重进行梯度更新。
具体的,根据第一损失、第二损失和第三损失中的至少一项及反向传播算法,对学生模型中生成器的卷积层的权重做梯度更新的过程可以参照步骤101的相关描述,此处不再赘述。
具体的,根据第一损失、第二损失和第三损失中的至少两项对学生模型中卷积层的权重做梯度更新时,是将对应的损失进行叠加,然后根据叠加后的损失,对学生模型中卷积层的权重做梯度更新。
如根据第一损失和第二损失,对学生模型中卷积层的权重做梯度更新时,则是将第一损失和第二损失进行叠加,然后根据叠加后所得到的总损失,对学生模型中卷积层的权重做梯度更新。
本申请实施例中,将老师模型的判别器确定为学生模型的判别器,然后根据学生模型对应的具体图像翻译模型(如pix2pix,pix2pixHD或vid2vid)的损失函数,计算学生模型的第一损失,再根据学生模型和老师模型中的第一隐藏变量计算第二损失,以及学生模型和老师模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失,最后根据第一损失、第二损失和第三损失,对学生模型的生成器中的卷积层的权重进行更新,也即根据老师模型,对学生模型进行知识蒸馏的过程。
这样,学生模型既可以因为训练数据量少,提升模型的训练速度,又可以通过知识蒸馏学习到老师模型生成图像的能力,即达到接近老师模型的高准确率的图像翻译质量。
进一步,在执行步骤305时,若根据第一损失、第二损失和第三损失中的至少一个损失,对学生模型中卷积层的权重做梯度更新时,还可以执行以下步骤,以实现对学生模型的不同调节,具体请参阅图5,本申请实施例中模型训练方法的另一个实施例,包括:
501、获取所述第一损失、所述第二损失和所述第三损失对应的权重;
本申请实施例,还可以对第一损失、第二损失和第三损失预先设置一定的权重,然后根据预先设置的权重执行步骤502。
502、根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
得到第一损失、第二损失和第三损失以及对应的权重后,根据根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失。
假设第一损失为A,第一损失的权重为20%,第二损失为B,第二损失的权重为60%,第三损失为C,第三损失的权重为20%,则目标损失为0.2A+0.6B+0.2C。
503、根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
得到目标损失后,根据目标损失和反向传播算法,对学生模型中生成器的卷积层的权重进行梯度更新。
具体的,对学生模型中生成器的卷积层的权重进行梯度更新的过程与步骤101中描述的类似,此处不再赘述。
本申请实施例中,可以为不同的损失,设置不同的权重,从而计算得到目标损失,最后根据目标损失和反向传播算法,对学生模型中生成器的卷积层的权重进行梯度更新,从而实现了对学生模型在不同损失方向上的倾向性修正。
上面对本申请实施例中的模型训练方法做了描述,下面接着对本申请实施例中的模型训练装置进行描述,请参阅图6,本申请实施例中描述训练装置的一个实施例,包括:
第一训练单元601,用于利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为GAN网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的distancemap数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的distancemap数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧图像以外的任意一帧或任意多帧图像;
第二训练单元602,用于利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据中的部分数据,其中,所述M为大于等于1且小于等于N的整数;
知识蒸馏单元603,用于利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
优选的,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。
优选的,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。
优选的,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述知识蒸馏单元603,包括:
设置模块6031,用于将所述老师模型中的判别器作为所述学生模型中的判别器;
第一计算模块6032,用于根据所述学生模型的损失函数,计算所述学生模型的第一损失;
第二计算模块6033,用于计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
第三计算模块6034,用于计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
更新模块6035,用于根据所述第一损失、所述第二损失和所述第三损失中的至少一项及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
具体的,所述更新模块6035具体包括:
获取子模块60351,用于获取所述第一损失、所述第二损失和所述第三损失对应的权重;
计算子模块60352,用于根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
更新子模块60353,用于根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
本申请实施例中,通过第一训练单元601利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为GAN网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的distancemap数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的distancemap数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;通过第二训练单元602利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;通过知识蒸馏单元603利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
本申请实施例中利用老师模型对所述学生模型进行知识蒸馏,使得学生模型在输入数据类型减少时,也可以达到接近老师模型的高准确率的图像翻译质量。
上面从模块化功能实体的角度对本发明实施例中的模型训练装置进行了描述,下面从硬件处理的角度对本发明实施例中的计算机装置进行描述:
该计算机装置用于实现模型训练装置的功能,本发明实施例中计算机装置一个实施例包括:
处理器以及存储器;
存储器用于存储计算机程序,处理器用于执行存储器中存储的计算机程序时,可以实现如下步骤:
利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
将所述老师模型中的判别器作为所述学生模型中的判别器;
根据所述学生模型的损失函数,计算所述学生模型的第一损失;
计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
获取所述第一损失、所述第二损失和所述第三损失对应的权重;
根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
可以理解的是,上述说明的计算机装置中的处理器执行所述计算机程序时,也可以实现上述对应的各装置实施例中各单元的功能,此处不再赘述。示例性的,所述计算机程序可以被分割成一个或多个模块/单元,所述一个或者多个模块/单元被存储在所述存储器中,并由所述处理器执行,以完成本发明。所述一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述所述计算机程序在所述模型训练装置中的执行过程。例如,所述计算机程序可以被分割成上述模型训练装置中的各单元,各单元可以实现如上述相应模型训练装置说明的具体功能。
所述计算机装置可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机装置可包括但不仅限于处理器、存储器。本领域技术人员可以理解,处理器、存储器仅仅是计算机装置的示例,并不构成对计算机装置的限定,可以包括更多或更少的部件,或者组合某些部件,或者不同的部件,例如所述计算机装置还可以包括输入输出设备、网络接入设备、总线等。
所述处理器可以是中央处理单元(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable GateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等,所述处理器是所述计算机装置的控制中心,利用各种接口和线路连接整个计算机装置的各个部分。
所述存储器可用于存储所述计算机程序和/或模块,所述处理器通过运行或执行存储在所述存储器内的计算机程序和/或模块,以及调用存储在存储器内的数据,实现所述计算机装置的各种功能。所述存储器可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序等;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非易失性存储器,例如硬盘、内存、插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(SecureDigital,SD)卡,闪存卡(Flash Card)、至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
本发明还提供了一种计算机可读存储介质,该计算机可读存储介质用于实现模型训练装置的功能,其上存储有计算机程序,计算机程序被处理器执行时,处理器,可以用于执行如下步骤:
利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
将所述老师模型中的判别器作为所述学生模型中的判别器;
根据所述学生模型的损失函数,计算所述学生模型的第一损失;
计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
获取所述第一损失、所述第二损失和所述第三损失对应的权重;
根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
可以理解的是,所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在相应的一个计算机可读取存储介质中。基于这样的理解,本发明实现上述相应的实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
以上所述,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。

Claims (10)

1.一种模型训练方法,其特征在于,所述方法包括:
利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
2.根据权利要求1所述的方法,其特征在于,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。
3.根据权利要求1所述的方法,其特征在于,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。
4.根据权利要求1-3中任一项所述的方法,其特征在于,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述利用所述老师模型对所述学生模型进行知识蒸馏,包括:
将所述老师模型中的判别器作为所述学生模型中的判别器;
根据所述学生模型的损失函数,计算所述学生模型的第一损失;
计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
5.根据权利要求4所述的方法,其特征在于,所述根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新,包括:
获取所述第一损失、所述第二损失和所述第三损失对应的权重;
根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;
根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
6.根据权利要求1所述的方法,其特征在于,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
7.一种模型训练装置,其特征在于,所述装置包括:
第一训练单元,用于利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
第二训练单元,用于利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
知识蒸馏单元,用于利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。
8.根据权利要求7所述的模型训练装置,其特征在于,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述知识蒸馏单元具体用于:
将所述老师模型中的判别器作为所述学生模型中的判别器;
根据所述学生模型的损失函数,计算所述学生模型的第一损失;
计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型中编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型中编码模型与解码模型之间的隐藏变量;
计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。
9.一种计算机装置,包括处理器,其特征在于,所述处理器在执行存储于存储器上的计算机程序时,用于实现如权利要求1至6中任一项所述的模型训练方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,用于实现如权利要求1至6中任一项所述的模型训练方法。
CN202110495293.5A 2021-05-07 2021-05-07 一种模型训练方法及模型训练装置 Active CN113160041B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110495293.5A CN113160041B (zh) 2021-05-07 2021-05-07 一种模型训练方法及模型训练装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110495293.5A CN113160041B (zh) 2021-05-07 2021-05-07 一种模型训练方法及模型训练装置

Publications (2)

Publication Number Publication Date
CN113160041A true CN113160041A (zh) 2021-07-23
CN113160041B CN113160041B (zh) 2024-02-23

Family

ID=76873720

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110495293.5A Active CN113160041B (zh) 2021-05-07 2021-05-07 一种模型训练方法及模型训练装置

Country Status (1)

Country Link
CN (1) CN113160041B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115085805A (zh) * 2022-06-09 2022-09-20 南京信息工程大学 一种基于对抗蒸馏模型的少模多芯光纤光性能监测方法、系统、装置及存储介质

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200134506A1 (en) * 2018-10-29 2020-04-30 Fujitsu Limited Model training method, data identification method and data identification device
CN111160533A (zh) * 2019-12-31 2020-05-15 中山大学 一种基于跨分辨率知识蒸馏的神经网络加速方法
CN111950302A (zh) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 基于知识蒸馏的机器翻译模型训练方法、装置、设备及介质
CN111967573A (zh) * 2020-07-15 2020-11-20 中国科学院深圳先进技术研究院 数据处理方法、装置、设备及计算机可读存储介质
CN112508120A (zh) * 2020-12-18 2021-03-16 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200134506A1 (en) * 2018-10-29 2020-04-30 Fujitsu Limited Model training method, data identification method and data identification device
CN111160533A (zh) * 2019-12-31 2020-05-15 中山大学 一种基于跨分辨率知识蒸馏的神经网络加速方法
CN111967573A (zh) * 2020-07-15 2020-11-20 中国科学院深圳先进技术研究院 数据处理方法、装置、设备及计算机可读存储介质
CN111950302A (zh) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 基于知识蒸馏的机器翻译模型训练方法、装置、设备及介质
CN112508120A (zh) * 2020-12-18 2021-03-16 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115085805A (zh) * 2022-06-09 2022-09-20 南京信息工程大学 一种基于对抗蒸馏模型的少模多芯光纤光性能监测方法、系统、装置及存储介质
CN115085805B (zh) * 2022-06-09 2024-03-19 南京信息工程大学 一种基于对抗蒸馏模型的光纤光性能监测方法及系统

Also Published As

Publication number Publication date
CN113160041B (zh) 2024-02-23

Similar Documents

Publication Publication Date Title
Foster Generative deep learning
Jaafra et al. Reinforcement learning for neural architecture search: A review
US9619749B2 (en) Neural network and method of neural network training
Zhong et al. Self-adaptive neural module transformer for visual question answering
CN107358626A (zh) 一种利用条件生成对抗网络计算视差的方法
US11354792B2 (en) System and methods for modeling creation workflows
CN111176758B (zh) 配置参数的推荐方法、装置、终端及存储介质
KR102602112B1 (ko) 얼굴 이미지 생성을 위한 데이터 프로세싱 방법 및 디바이스, 및 매체
CN107169573A (zh) 利用复合机器学习模型来执行预测的方法及系统
JPH06509669A (ja) 改良ニューラル・ネットワーク
CN108121995A (zh) 用于识别对象的方法和设备
Ma et al. Learning and exploring motor skills with spacetime bounds
WO2023174036A1 (zh) 联邦学习模型训练方法、电子设备及存储介质
CN105701540A (zh) 一种自生成神经网络构建方法
CN113633983A (zh) 虚拟角色表情控制的方法、装置、电子设备及介质
CN116363308A (zh) 人体三维重建模型训练方法、人体三维重建方法和设备
CN116188621A (zh) 基于文本监督的双向数据流生成对抗网络图像生成方法
CN113160041A (zh) 一种模型训练方法及模型训练装置
CN115526223A (zh) 潜在空间中的基于得分的生成建模
KR102470866B1 (ko) 3d 캐릭터의 얼굴 표정 리타게팅 방법 및 이를 위해 신경망을 학습하는 방법
CN116704079B (zh) 图像生成方法、装置、设备及存储介质
CN113077383B (zh) 一种模型训练方法及模型训练装置
AU2022241513B2 (en) Transformer-based shape models
US20230237725A1 (en) Data-driven physics-based models with implicit actuations
CN113112400A (zh) 一种模型训练方法及模型训练装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant