CN113112400B - 一种模型训练方法及模型训练装置 - Google Patents
一种模型训练方法及模型训练装置 Download PDFInfo
- Publication number
- CN113112400B CN113112400B CN202110496339.5A CN202110496339A CN113112400B CN 113112400 B CN113112400 B CN 113112400B CN 202110496339 A CN202110496339 A CN 202110496339A CN 113112400 B CN113112400 B CN 113112400B
- Authority
- CN
- China
- Prior art keywords
- image
- frame image
- model
- data
- translation model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 177
- 238000000034 method Methods 0.000 title claims abstract description 70
- 238000013519 translation Methods 0.000 claims abstract description 345
- 230000006870 function Effects 0.000 claims description 77
- 238000004422 calculation algorithm Methods 0.000 claims description 46
- 238000004590 computer program Methods 0.000 claims description 20
- 230000004048 modification Effects 0.000 claims description 5
- 238000012986 modification Methods 0.000 claims description 5
- 238000004088 simulation Methods 0.000 abstract description 21
- 230000008569 process Effects 0.000 description 19
- 238000010586 diagram Methods 0.000 description 16
- 238000004364 calculation method Methods 0.000 description 13
- 238000013528 artificial neural network Methods 0.000 description 12
- 238000012545 processing Methods 0.000 description 5
- 238000013135 deep learning Methods 0.000 description 4
- 210000002569 neuron Anatomy 0.000 description 4
- 230000003287 optical effect Effects 0.000 description 4
- 238000006243 chemical reaction Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 230000009467 reduction Effects 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 230000008859 change Effects 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 238000003491 array Methods 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T3/00—Geometric image transformations in the plane of the image
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T9/00—Image coding
- G06T9/002—Image coding using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本发明实施例的模型训练方法及训练装置,用于提升对虚拟数字人姿态的仿真速度。本发明实施例方法包括:利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,生成器为编码模型‑解码模型结构,编码模型采用的是残差网络架构,第一数据包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据、目标帧图像的前N帧图像、前N帧图像的轮廓线数据和距离图像数据,其中,N为大于等于2的整数,目标帧图像为训练集中除了第一帧图像和第二帧图像以外的任一帧或任意多帧图像;将第一图像翻译模型的编码模型的残差网络架构修改为轻量模型架构,得到第二图像翻译模型。
Description
技术领域
本发明涉及图像翻译技术领域,尤其涉及一种模型训练方法及模型训练装置。
背景技术
所谓图像翻译,指从一副图像到另一副图像的转换。可以类比机器翻译,将一种语言转换为另一种语言。
现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid。pix2pix提出了一个统一的框架解决了各类图像翻译问题,pix2pixHD则在pix2pix的基础上,较好地解决了高分辨率图像转换(翻译)的问题,vid2vid则在pix2pixHD的基础上,较好地解决了高分辨率的视频转换问题。
数字人,是一种利用信息科学的方法对人体在不同水平的形态和功能进行的虚拟仿真。而目前的图像翻译模型,可以对图像中的数字人姿态进行虚拟仿真,但现有技术中的图像翻译模型,因为模型架构复杂,从而导致在训练过程中的数据运算量大,进而使得图像翻译模型的图像翻译速度较慢,也即数字人姿态仿真的速度较慢。
发明内容
本发明实施例提供了一种模型训练方法及训练装置,用于提升图像翻译模型对虚拟数字人姿态的仿真速度。
本申请实施例第一方面提供了一种模型训练方法,包括:
利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
优选的,所述方法还包括:
将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷子算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
优选的,所述方法还包括:
利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
优选的,所述方法还包括:
利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
优选的,所述方法还包括:
利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
优选的,所述方法还包括:
利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
优选的,所述轻量模型架构包括:
所述轻量模型架构包括MobileNet架构、ShuffleNet架构、SqueezeNet架构和Xception架构中的至少一种。
优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
本申请实施例第二方面提供了一种模型训练装置,包括:
第一训练单元,用于利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
第一修改单元,用于将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
优选的,所述装置还包括:
第二修改单元,用于将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷子算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
优选的,所述装置还包括:
第二训练单元,用于利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
第一计算单元,用于根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
第一更新单元,用于根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
优选的,所述装置还包括:
第三训练单元,用于利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
第二计算单元,用于根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
第二更新单元,用于根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
优选的,所述装置还包括:
第四训练单元,用于利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
第三计算单元,用于根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
第三更新单元,用于根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
优选的,所述装置还包括:
第五训练单元,用于利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
第四计算单元,用于根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
第四更新单元,用于根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
优选的,所述轻量模型架构包括:
所述轻量模型架构包括MobileNet架构、ShuffleNet架构、SqueezeNet架构和Xception架构中的至少一种。
优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
本申请实施例还提供了一种计算机装置,包括处理器,所述处理器在执行存储于存储器上的计算机程序时,用于实现本申请实施例第一方面提供的模型训练方法。
本申请实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时,用于实现本申请实施例第一方面提供的模型训练方法。
从以上技术方案可以看出,本发明实施例具有以下优点:
本申请实施例中,利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器采用的是编码-解码模型结构,所述编码模型采用的是残差网络架构,所述训练数集包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的任意一帧或任意多针图像;将所述第一图像翻译模型生成器中编码模型的残差网络架构修改为轻量模型架构,并将修改后的第一图像翻译模型作为第二图像翻译模型。
本申请实施例所得到的第二图像翻译模型,其生成器中的编码模型为轻量模型架构,因为轻量模型架构相比于残差网络架构而言,对同样输入数据的运算量明显减少,从而提升了图像翻译的速度,也即提升了对数字人姿态的仿真速度。
附图说明
图1为本申请实施例一种模型训练方法的一个实施例示意图;
图2为本申请实施例中残差网络架构(RestnetBlock)的结构示意图;
图3为本申请实施例中MobileNet架构的结构示意;
图4为本申请实施例一种模型训练方法的另一个实施例示意图;
图5为本申请实施例中第一图像翻译模型的生成器和第三图像翻译模型的生成器结构之间的对比示意图;
图6为本申请实施例一种模型训练方法的另一个实施例示意图;
图7为本申请实施例中神经网络结构的示意图;
图8为本申请实施例一种模型训练方法的另一个实施例示意图;
图9为本申请实施例一种模型训练方法的另一个实施例示意图;
图10为本申请实施例一种模型训练方法的另一个实施例示意图;
图11为本申请实施例一种模型训练装置的一个实施例示意图。
具体实施方式
本发明实施例提供了一种模型训练方法及装置,用于减少模型的运算量,以提升模型推理的速度,即提升模型生成图像帧的速度,使得该图像翻译模型用于数字人姿态仿真时,也可以提升数字人姿态的仿真速度。
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的实施例能够以除了在这里图示或描述的内容以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
基于现有技术中,目前的图像翻译模型在对图像中的数字人姿态进行虚拟仿真时,因为模型架构复杂,而导致模型对输入数据的运算量大,进而导致图像翻译模型在对数字人姿态进行虚拟仿真时出现的仿真速度慢的问题,本申请提出了一种模型训练方法及训练装置,用于提升图像翻译模型对数字人姿态仿真的速度。
为方便理解,下面对本申请中的模型训练方法进行说明,请参阅图1,本申请中模型训练方法的一个实施例,包括:
101、利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
具体的,现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid,其中,各图像翻译模型采用的都是GAN网络模型,其中,GAN网络应用到深度学习神经网络上来说,就是通过生成器G(Generator)和判别器D(Discriminator)不断博弈,进而使G学习到数据的分布,如果用到图片生成上,则训练完成后,G可以从一段随机数中生成逼真的图像。
目前,现有技术中图像翻译模型的生成器G采用的是编码模型-解码模型,即encoder-decoder模型,其中,编码模型和解码模型可以采用CNN,RNN,BiRNN、LSTM等深度学习算法中的任一种,现在的深度学习算法当网络很深的时候,模型效果却越来越差了,且通过实验可以发现:随着网络层级的不断增加,模型精度不断得到提升,而当网络层级增加到一定的数目以后,训练精度和测试精度迅速下降,这说明当网络变得很深以后,深度网络就变得更加难以训练了,所以为了减小误差,现有的深度学习算法通过采用残差网络架构(RestNetBlock)来保持模型的精准度,其中,图2给出了残差网络结构(RestNetBlock)的示意图,通过这种残差网络结构,可以在网络层很深时,还可以保持模型的精准度,本申请实施中第一图像翻译模型中的生成模型即采用了RestNet网络模型。
进一步,本申请中采用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像(distancemap)数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像(distancemap)数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像。
假设训练集中任一帧图像的大小为X*X*3,其中,X为图片的像素大小,而3为图片通道的个数,其中3通道代表图片为RGB图像,则第一数据中目标帧图像的前N帧图像一起为N*3维数据,而训练集中任一帧图像的distancemap数据为4维数据,则目标帧图像及目标帧图像的前N帧图像的distancemap数据一起为(N+1)*4维数据,而训练集中任一帧图像的轮廓线数据为1维数据,则目标帧图像及目标帧图像的前N帧图像的轮廓线数据为N+1维数据,则第一数据的维度一起为8N+5维。
因为第一数据为目标帧图像及目标帧前N帧的图像,而N大于等于2,所以目标帧图像为训练集中除了第一帧及第二帧以外的任意一帧或任意多帧图像。
102、将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
当第一图像翻译模型中编码模型的RestNet模型的层数很大时,第一图像翻译模型就会面临模型过于庞大,而导致的数据运算量大及高延时的问题,也即利用第一图像翻译模型在实现数字人姿态仿真时,会出现仿真速度慢、仿真延时大的问题。
针对该问题,本申请实施例将第一图像翻译模型的编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
具体的,本实施例中的轻量模型架构,包括MobileNet架构、ShuffleNet架构、SqueezeNet架构和Xception架构中的至少一种。
为方便理解,下面以MobileNet架构为例,对编码模型由RestNet架构修改为MobileNet架构以后,数据运算量的变化进行说明:
其中,图3给出了MobileNet架构的示意图,在图3中,假设输入图像、输出图像和特征图(featureMap)的大小都为M*M*N,则在图3的示意图中,MobileNet架构的计算量为:
M*M*N*3*3+1*1*N*N*M*M+M*M*N*1*1*N*2=M*M*N*(N+9+2)
在图2中,则RestNet架构的计算量为:
M*M*N*1*1*N+M*M*N*3*3*N+M*M*N*1*1*N*2=M*M*N*N*(1+9+2)
则MobileNet架构的计算量与RestNet架构的计算量的比值为:(11+N)/12N;
由此可知,N的取值越大,计算量节省的越大,随着卷积核个数的增加,即通道数变多,MobileNet架构的计算量要比RestNet架构的计算量小的多。
至于其他轻量模型架构的运算量变化原理,在现有技术中都有详细描述,此处不再赘述。
本申请实施例中,利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器采用的是编码-解码模型结构,所述编码模型采用的是残差网络架构,所述训练数集包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的任意一帧或任意多针图像;将所述第一图像翻译模型生成器中编码模型的残差网络架构修改为轻量模型架构,并将修改后的第一图像翻译模型作为第二图像翻译模型。
因为本申请实施例所得到的第二图像翻译模型,其生成器中的编码模型为轻量模型架构,而轻量模型架构相比于残差网络架构而言,对同样输入数据的运算量明显减少,从而提升了图像翻译的速度,也即提升了对数字人姿态的仿真速度。
基于在图1所述的实施例中,第二图像翻译模型生成器的编码模型的首层一般为一个7*7的卷积层,用于提取输入图像的特征,而为了减少首层中的卷积核与输入图像之间的运算量,还可以执行以下步骤,请参阅图4,本申请实施例中模型训练方法的另一个实施例,包括:
401、将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷子算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
基于图1所述的实施例,为了减少第二图像翻译模型的生成器中的编码模型首层中的卷积核与输入图像之间的计算量,还可以将编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,预设数量的小卷子算子和大卷子算子在输入相同的输入数据时,预设数量的小卷子算子对输入数据的数据数据运算量较小。
为方便理解,下面举例说明:
一般的,第二图像翻译模型的编码模型的首层为一个7*7的卷积层,为了减少该卷积层与输入图像之间的运算量,可以将该7*7的卷积层修改为2个3*3的小卷积核,且保证输入和输入的维度保持不变,这样,大卷积算子和预设数量的小卷积算子在对相同的输入图像(假设为N*N)进行运算时,它们之间运算量的比值为:(N*N*7*7)/(N*N*2*3*3)=2.72,从而使得计算量减少了2.72倍。
进一步,假设编码模型的首层有2个7*7的卷积核时,则上述运算量则减少了5.44倍。
由此可知,本申请实施例中将第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,可以减少数据的运算量,且保证输入和输出图像的维度保持不变,从而进一步节省了数据的运算量,提升了第三图像翻译模型对虚拟数字人的仿真速度。
为方便理解,图5给出了第一图像翻译模型的生成器和第三图像翻译模型的生成器之间的对比示意图。
在图5的示意图中,将第一图像翻译模型的生成器中的编码模型首层中的7*7的卷积算子修改为2个3*3的卷积算子,且将编码模型中的RestNet架构修改为MobileNet架构,不仅节省了编码模型首层对输入图像的运算量,还通过MobileNet架构进一步节省了对输入图像的运算量,从而提升了第三图像翻译模型对虚拟数字人的仿真速度。
基于图4得到的第三图像翻译模型,在训练过程中,若还采用第一数据进行训练,则会因为训练数据量大,导致第三图像翻译模型在实现数字人姿态仿真时,需要的输入数据较多,从而导致仿真速度较慢,针对该问题,还可以采用训练集中的第二数据对第三图像翻译模型进行训练,以得到第四图像翻译模型,具体请参阅图6,本申请实施例中模型训练方法的另一个实施例,包括:
601、利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
若采用第一数据对第三图像翻译模型进行训练,根据图1所述的实施例,第一数据所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像(distancemap)数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像(distancemap)数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像。
假设训练集中任一帧图像的大小为X*X*3,其中,X为图片的像素大小,而3为图片通道的个数,其中3通道代表图片为RGB图像,则第一数据中目标帧图像及目标帧图像的前N帧图像一起为N*3维数据,而训练集中任一帧图像的distancemap数据为4维数据,则目标帧图像及目标帧图像的前N帧图像的distancemap数据一起为(N+1)*4维数据,而训练集中任一帧图像的轮廓线数据为1维数据,则目标帧图像及目标帧图像的前N帧图像的轮廓线数据为N+1维数据,则第一数据的维度一起为8N+5维。
而为了节省模型对输入数据的运算量,本申请实施例还可以采用第二数据对第三图像翻译模型中的生成器进行训练,其中,第二数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据。
假设训练集中任一帧图像的大小为X*X*3,其中,X为图片的像素大小,而3为图片通道的个数,其中3通道代表图片为RGB图像,则第一数据中目标帧图像及目标帧图像的前N帧图像一起为N*3维数据,而目标帧图像的轮廓线数据及目标帧图像的前N帧图像的轮廓线数据,一起为(N+1)*1维数据,则第二数据一起为4N+1维,这样,第二数据的维度相比于第一数据的维度则少了4N+4,故通过本申请实施例训练后的第四图像翻译模型,相比于第三图像翻译模型而言,图像推理的速度更快,也即实现虚拟数字人姿态仿真时,仿真的速度也更快。
602、根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
具体的,本申请实施例中的第三图像翻译模型包括pix2pix,pix2pixHD和vid2vid中的至少一种,对于每一种图像具体的翻译模型,其对应的损失函数也有所不同:
对于pix2pix而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss,及使得输出多样化的GANLoss;
对于pix2pixHD而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、及Content Loss;
对于vid2vid而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、Content Loss、视频Loss和光流Loss;
对于每一种Loss函数在现有技术中都有详细描述,此处不再赘述。
得到第三图像翻译模型的损失函数后,根据具体的损失函数,计算第三图像翻译模型的第一损失。
603、根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
根据第一损失和反向传播算法,对第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
为方便理解梯度更新的过程,先对GAN网络中的生成器做简单描述:
图像翻译模型中的生成器采用的是神经网络算法,而多层感知器(Multi-LayerPerceptron,MLP)也叫人工神经网络(Artificial Neural Network,ANN),一般包括输入层、输出层,及设于输入层和输出层之间的多个隐层。最简单的MLP需要有一层隐层,即输入层、隐层和输出层才能称为一个简单的神经网络。
接下来以图7中的神经网络为例,对数据的传导过程进行描述:
1、神经网络的前向输出
其中,第0层(输入层),我们将x1、x2、x3向量化为X;
0层和1层(隐层)之间,存在权重w1、w2、w3,将权重向量化为W[1],其中W[1]表示第一层的权重;
0层和1层(隐层)之间,还存在偏置b1、b2、b3,将其向量化为b[1],其中b[1]表示第一层的权重;
对于第1层,计算公式为:
Z[1]=W[1]X+b[1];
A[1]=sigmoid(Z[1]);
其中,Z为输入值的线型组合,A为Z通过激活函数sigmoid的值,对于第一层的输入值X,输出值为A,也是下一层的输入值,在sigmoid激活函数中,其取值在[0,1]之间,可以将其理解为一个阀,就像人的神经元一样,当一个神经元受到刺激,并不是立刻感觉到,而是这个刺激超声了阀值,才会让神经元向上级传播。
1层和2层(输出层)之间,与0层和1层之间类似,其计算公式如下:
Z[2]=W[2]X+b[2]
A[2]=sigmoid(Z[2])
yhat=A[2];
其中yhat即为本次神经网络的输出值。
2、损失函数
在神经网络训练的过程中,一般通过损失函数来衡量这个神经网络是否训练到位。
一般情况下,我们选择如下函数作为作为损失函数:
其中,y为图片的真实特征值,yhat为生成图片的特征值;
当y=1时,若yhat越接近1,越接近0,表示预测效果越好,当损失函数达到最小值时,说明生成模型所生成的当前帧的生成图像越接近于当前帧的原始图像。
3、反向传播算法
在上述神经网络模型中,可以通过计算损失函数来得到神经网络的训练效果,同时还可以通过反向传播算法,来更新参数,使得神经网络模型更能得到我们想要的预测值。其中,梯度下降算法即为一种优化权重W和偏置b的方法。
具体的,梯度下降算法是对损失函数求偏导数,然后用偏导数来更新w1、w2和b。
为方便理解,我们将损失函数公式化为以下的公式:
z=w1x1+w2x2+b;
然后,分别将对α和z求导:
再对w1、w2和b求导:
接下来用梯度下降算法更新权重参数w和偏置参数b:
其中,w1:=w1-βdw1
w2:=w2-βdw2
b:=b-βdb。
其中,β表示学习率,即学习的步长,在实际训练过程中,如果学习率过大,则会在最优解附近来回震荡,无法达到最优解,如果学习率过小,则可能需要很多次的迭代,才能达到最优解,所以在实际训练过程中,学习率也是一个很重要的选择参数。
而本申请中对生成器的训练过程,即根据第三图像翻译模型中的损失函数计算对应损失,然后利用反向传播算法,对生成器中卷积层的权重进行更新的过程,而具体的更新过程可以参阅上述损失函数和反向传播算法的计算过程。
因为本申请实施例采用第二数据对第三图像翻译模型进行训练,得到第四图像翻译模型,使得第四图像翻译模型相比于第三图像翻译模型而言,在对图像进行推理时,所需要的输入数据量更少,即从原来的8N+8维数据,减少为4N+4维数据,从而提升了图像推理的速度,也即提升了对数字虚拟人姿态的仿真速度。
基于图6实施例所得到的第四图像翻译模型,为了加快第四图像翻译模型的推理速度,还可以对第四图像翻译模型进行训练,得到第五图像翻译模型,下面对第五图像翻译模型的训练过程进行描述,请参阅图8,本申请实施例中模型训练方法的另一个实施例,包括:
801、利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
在实际训练过程中,为了加快第四图像翻译模型的推理速度,也即数字人姿态的仿真速度,本申请实施例采用训练集中的第三数据对第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数。
为方便理解,下面对第二数据和第三数据的区别进行说明:
假设N=2,则M=1;
则第二数据为目标帧图像及目标帧图像的前2帧图像、目标帧图像的轮廓线数据以及目标帧图像的前2帧图像的轮廓线数据,假设训练集中每一帧图像的大小为X*X*3,其中,X为图片的像素大小,而3为图片通道的个数,其中3通道代表图片为RGB图像,则第二数据总的数据维度为2*3+3*1=9维;而第三数据为目标帧图像及目标帧图像的前1帧图像、目标帧图像的轮廓线数据以及目标帧图像的前1帧图像的轮廓线数据,则第三数据总的数据维度为1*3+2*1=5维。
也就是说,在使用第二数据对第三图像翻译模型进行训练的过程中,是利用目标帧的轮廓线数据及目标帧图像的前2帧图像、以及目标帧图像的前2帧图像的轮廓线数据,推理得到目标帧的生成图像,而利用第三数据对第四图像翻译模型进行训练的过程中,是利用目标帧的轮廓线数据及目标帧图像的前1帧图像、以及目标帧图像的前1帧图像的轮廓线数据,推理得到目标帧的生成图像,其中,训练过程中的目标帧图像是用于计算第三图像翻译模型或第四图像翻译模型的损失,而不参与具体的图像推理过程。
802、根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
具体的,本申请实施例中的第四图像翻译模型包括pix2pix,pix2pixHD和vid2vid中的至少一种,对于每一种图像具体的翻译模型,其对应的损失函数也有所不同:
对于pix2pix而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss,及使得输出多样化的GANLoss;
对于pix2pixHD而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、及Content Loss;
对于vid2vid而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、Content Loss、视频Loss和光流Loss;
对于每一种Loss函数在现有技术中都有详细描述,此处不再赘述。
得到第四图像翻译模型的损失函数后,根据具体的损失函数,计算第四图像翻译模型的第二损失。
803、根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
得到第四图像翻译模型的第二损失后,根据第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
具体的,根据第二损失和反向传播算法,对第四图像翻译模型中生成器的卷积层的权重进行梯度更新的过程,可以参阅步骤603的相关描述,此处不再赘述。
因为本申请实施例采用第三数据对第四图像翻译模型进行训练,得到第五图像翻译模型,使得第五图像翻译模型相比于第四图像翻译模型而言,在对图像进行推理时,所需要的输入数据量更少,即从原来的4N+1维数据,减少为4M+1维数据,其中,M为小于N的整数,从而提升了图像推理的速度,也即提升了对数字虚拟人姿态的仿真速度。
基于图8实施例所得到的第五图像翻译模型,为了进一步加快第五图像翻译模型的推理速度,还可以对第五图像翻译模型继续进行训练,下面对第五图像翻译模型的训练过程进行描述,请参阅图9,本申请实施例中模型训练方法的另一个实施例,包括:
901、利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
为了进一步提升第五图像翻译模型的图像推理速度,还可以利用训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
下面对第三数据和第四数据的区别进行说明:
假设第三数据中的M=1,则第三数据包括:目标帧图像及目标帧图像的前1帧图像、目标帧图像的轮廓线数据以及目标帧图像的前1帧图像的轮廓线数据;而第四数据包括:目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据。
也就是说,在利用第三数据推理目标帧的生成图像时,需要依赖目标帧图像的前1帧图像、目标帧图像的轮廓线数据以及目标帧图像的前1帧图像的轮廓线数据,而利用第四数据推理目标帧的生成图像时,只需要依赖目标帧图像的轮廓线数据、第一帧图像及第一帧图像的轮廓线数据。
由此可知,在利用第四数据推理目标帧的生成图像时,只要第一帧图像的信息固定,即可根据第一帧图像、第一帧图像的轮廓线数据和目标帧图像的轮廓线数据,得到目标帧的生成图像;而在利用第三数据推理目标帧的生成图像时,则只有在目标帧图像前M帧图像信息都固定后,才可以得到目标帧的生成图像,明显增加了目标帧图像前M帧图像的推理时间,故在利用第四数据推理目标帧的生成图像时,推理速度更快。
902、根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
具体的,本申请实施例中的第五图像翻译模型包括pix2pix,pix2pixHD和vid2vid中的至少一种,对于每一种图像具体的翻译模型,其对应的损失函数也有所不同:
对于pix2pix而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss,及使得输出多样化的GANLoss;
对于pix2pixHD而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、及Content Loss;
对于vid2vid而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、Content Loss、视频Loss和光流Loss;
对于每一种Loss函数在现有技术中都有详细描述,此处不再赘述。
得到第五图像翻译模型的损失函数后,根据具体的损失函数,计算第五图像翻译模型的第三损失。
903、根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
得到第五图像翻译模型的第三损失后,根据第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
具体的,根据第三损失和反向传播算法,对第五图像翻译模型中生成器的卷积层的权重进行梯度更新的过程,可以参阅步骤603的相关描述,此处不再赘述。
因为本申请实施例采用第四数据对第五图像翻译模型进行训练,得到第六图像翻译模型,使得第六图像翻译模型相比于第五图像翻译模型而言,在对图像进行推理时,所需要的推理时间更短,即从依赖目标帧前M帧图像信息,变更为只依赖第一帧图像的信息,从而提升了图像推理的速度,也即提升了对数字虚拟人姿态的仿真速度。
基于图9实施例所得到的第六图像翻译模型,为了进一步加快第六图像翻译模型的推理速度,还可以对第六图像翻译模型继续进行训练,下面对第六图像翻译模型的训练过程进行描述,请参阅图10,本申请实施例中模型训练方法的另一个实施例,包括:
1001、利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
为了进一步提升第六图像翻译模型的图像推理速度,还可以利用训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、降低像素后的第一帧图像和第一帧图像的轮廓线数据;
下面对第五数据和第四数据的区别进行说明:
本申请实施例中的第五数据包括:目标帧图像、所述目标帧图像的轮廓线数据、降低像素后的第一帧图像和第一帧图像的轮廓线数据;而第四数据包括:目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据。
也即第五数据中的第一帧图像为降低像素后的第一帧图像,而降低像素后的第一帧图像,相比于第一帧图像而言,参与训练的数据量更少,相应的利用第五数据,相比于第四数据而言,图像翻译模型的推理速度也更快。
1002、根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
具体的,本申请实施例中的第六图像翻译模型包括pix2pix,pix2pixHD和vid2vid中的至少一种,对于每一种图像具体的翻译模型,其对应的损失函数也有所不同:
对于pix2pix而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss,及使得输出多样化的GANLoss;
对于pix2pixHD而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、及Content Loss;
对于vid2vid而言,预设的损失函数包括目标帧图像与目标帧生成图像之间的L1Loss、使得输出多样化的GANLoss、Feature matching Loss、Content Loss、视频Loss和光流Loss;
对于每一种Loss函数在现有技术中都有详细描述,此处不再赘述。
得到第六图像翻译模型的损失函数后,根据具体的损失函数,计算第六图像翻译模型的第四损失。
1003、根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
得到第五图像翻译模型的第三损失后,根据第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
具体的,根据第四损失和反向传播算法,对第六图像翻译模型中生成器的卷积层的权重进行梯度更新的过程,可以参阅步骤603的相关描述,此处不再赘述。
因为本申请实施例采用第五数据对第六图像翻译模型进行训练,得到第七图像翻译模型,使得第七图像翻译模型相比于第六图像翻译模型而言,在对图像进行推理时,所需要的推理时间更短,即参与推理的数据量更少,从而提升了图像推理的速度,也即提升了对数字虚拟人姿态的仿真速度。
上面对本申请实施例中的模型训练方法做了描述,下面接着对本申请实施例中的模型训练装置进行描述,请参阅图11,本申请实施例中模型训练装置的一个实施例,包括:
第一训练单元1101,用于利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
第一修改单元1102,用于将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
优选的,所述装置还包括:
第二修改单元1103,用于将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷子算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
优选的,所述装置还包括:
第二训练单元1104,用于利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
第一计算单元1105,用于根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
第一更新单元1106,用于根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
优选的,所述装置还包括:
第三训练单元1107,用于利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
第二计算单元1108,用于根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
第二更新单元1109,用于根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
优选的,所述装置还包括:
第四训练单元1110,用于利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
第三计算单元1111,用于根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
第三更新单元1112,用于根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
优选的,所述装置还包括:
第五训练单元1113,用于利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
第四计算单元1114,用于根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
第四更新单元1115,用于根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
优选的,所述轻量模型架构包括:
所述轻量模型架构包括MobileNet架构、ShuffleNet架构、SqueezeNet架构和Xception架构中的至少一种。
优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
需要说明的是,本申请实施例中各单元的作用与图1至图10实施例中描述的类似,此处不再赘述。
本申请实施例中,通过第一训练单元1101利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器采用的是编码-解码模型结构,所述编码模型采用的是残差网络架构,所述训练数集包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的任意一帧或任意多针图像;通过第一修改单元1102将所述第一图像翻译模型生成器中编码模型的残差网络架构修改为轻量模型架构,并将修改后的第一图像翻译模型作为第二图像翻译模型。
本申请实施例所得到的第二图像翻译模型,其生成器中的编码模型为轻量模型架构,因为轻量模型架构相比于残差网络架构而言,对同样输入数据的运算量明显减少,从而提升了图像翻译的速度,也即提升了对数字人姿态的仿真速度。
上面从模块化功能实体的角度对本发明实施例中的模型训练装置进行了描述,下面从硬件处理的角度对本发明实施例中的计算机装置进行描述:
该计算机装置用于实现模型训练装置的功能,本发明实施例中计算机装置一个实施例包括:
处理器以及存储器;
存储器用于存储计算机程序,处理器用于执行存储器中存储的计算机程序时,可以实现如下步骤:
利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷子算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
可以理解的是,上述说明的计算机装置中的处理器执行所述计算机程序时,也可以实现上述对应的各装置实施例中各单元的功能,此处不再赘述。示例性的,所述计算机程序可以被分割成一个或多个模块/单元,所述一个或者多个模块/单元被存储在所述存储器中,并由所述处理器执行,以完成本发明。所述一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述所述计算机程序在所述模型训练装置中的执行过程。例如,所述计算机程序可以被分割成上述模型训练装置中的各单元,各单元可以实现如上述相应模型训练装置说明的具体功能。
所述计算机装置可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机装置可包括但不仅限于处理器、存储器。本领域技术人员可以理解,处理器、存储器仅仅是计算机装置的示例,并不构成对计算机装置的限定,可以包括更多或更少的部件,或者组合某些部件,或者不同的部件,例如所述计算机装置还可以包括输入输出设备、网络接入设备、总线等。
所述处理器可以是中央处理单元(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable GateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等,所述处理器是所述计算机装置的控制中心,利用各种接口和线路连接整个计算机装置的各个部分。
所述存储器可用于存储所述计算机程序和/或模块,所述处理器通过运行或执行存储在所述存储器内的计算机程序和/或模块,以及调用存储在存储器内的数据,实现所述计算机装置的各种功能。所述存储器可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序等;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非易失性存储器,例如硬盘、内存、插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(SecureDigital,SD)卡,闪存卡(Flash Card)、至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
本发明还提供了一种计算机可读存储介质,该计算机可读存储介质用于实现模型训练装置的功能,其上存储有计算机程序,计算机程序被处理器执行时,处理器,可以用于执行如下步骤:
利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷子算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
以上所述,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (11)
1.一种模型训练方法,其特征在于,所述方法包括:
利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
2.根据权利要求1所述的方法,其特征在于,所述方法还包括:
将所述第二图像翻译模型的生成器中的编码模型首层中的大卷积算子修改为预设数量的小卷积算子,以得到第三图像翻译模型,其中,所述预设数量的小卷积算子和所述大卷积算子在输入相同的输入数据时,所述预设数量的小卷积算子对所述输入数据的数据运算量较小。
3.根据权利要求2所述的方法,其特征在于,所述方法还包括:
利用所述训练集中的第二数据对所述第三图像翻译模型中的生成器进行训练,其中,所述第二数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前N帧图像和所述前N帧图像的轮廓线数据;
根据所述第三图像翻译模型的损失函数,计算所述第三图像翻译模型的第一损失;
根据所述第一损失和反向传播算法,对所述第三图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第四图像翻译模型。
4.根据权利要求3所述的方法,其特征在于,所述方法还包括:
利用所述训练集中的第三数据对所述第四图像翻译模型中的生成器进行训练,其中,所述第三数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据,其中,所述M为大于等于1且小于N的整数;
根据所述第四图像翻译模型的损失函数,计算所述第四图像翻译模型的第二损失;
根据所述第二损失和反向传播算法,对所述第四图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第五图像翻译模型。
5.根据权利要求4所述的方法,其特征在于,所述方法还包括:
利用所述训练集中的第四数据对所述第五图像翻译模型中的生成器进行训练,其中,所述第四数据包括所述目标帧图像、所述目标帧图像的轮廓线数据、所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第五图像翻译模型的损失函数,计算所述第五图像翻译模型的第三损失;
根据所述第三损失和反向传播算法,对所述第五图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第六图像翻译模型。
6.根据权利要求5所述的方法,其特征在于,所述方法还包括:
利用所述训练集中的第五数据对所述第六图像翻译模型中的生成器进行训练,其中,所述第五数据包括所述目标帧图像、所述目标帧图像的轮廓线数据,降低像素后的所述第一帧图像和所述第一帧图像的轮廓线数据;
根据所述第六图像翻译模型的损失函数,计算所述第六图像翻译模型的第四损失;
根据所述第四损失和反向传播算法,对所述第六图像翻译模型中生成器的卷积层的权重进行梯度更新,以得到第七图像翻译模型。
7.根据权利要求1至6中任一项所述的方法,其特征在于,所述轻量模型架构包括:
所述轻量模型架构包括MobileNet架构、ShuffleNet架构、SqueezeNet架构和Xception架构中的至少一种。
8.根据权利要求7所述的方法,其特征在于,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。
9.一种模型训练装置,其特征在于,所述装置包括:
第一训练单元,用于利用训练集中的第一数据对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为第一图像翻译模型,其中,所述生成器为编码模型-解码模型结构,所述编码模型采用的是残差网络架构,所述第一数据包括目标帧图像、所述目标帧图像的轮廓线数据、所述目标帧图像的距离图像数据、所述目标帧图像的前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,其中,所述N为大于等于2的整数,所述目标帧图像为所述训练集中除了第一帧图像和第二帧图像以外的任意一帧或任意多帧图像;
第一修改单元,用于将所述第一图像翻译模型的所述编码模型的残差网络架构修改为轻量模型架构,以得到第二图像翻译模型。
10.一种计算机装置,包括处理器,其特征在于,所述处理器在执行存储于存储器上的计算机程序时,用于实现如权利要求1至8中任一项所述的模型训练方法。
11.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,用于实现如权利要求1至8中任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110496339.5A CN113112400B (zh) | 2021-05-07 | 2021-05-07 | 一种模型训练方法及模型训练装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110496339.5A CN113112400B (zh) | 2021-05-07 | 2021-05-07 | 一种模型训练方法及模型训练装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113112400A CN113112400A (zh) | 2021-07-13 |
CN113112400B true CN113112400B (zh) | 2024-04-09 |
Family
ID=76721200
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110496339.5A Active CN113112400B (zh) | 2021-05-07 | 2021-05-07 | 一种模型训练方法及模型训练装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113112400B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116563246B (zh) * | 2023-05-10 | 2024-01-30 | 之江实验室 | 一种用于医学影像辅助诊断的训练样本生成方法及装置 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111091493A (zh) * | 2019-12-24 | 2020-05-01 | 北京达佳互联信息技术有限公司 | 图像翻译模型训练方法、图像翻译方法及装置和电子设备 |
CN111222560A (zh) * | 2019-12-30 | 2020-06-02 | 深圳大学 | 一种图像处理模型生成方法、智能终端及存储介质 |
CN111833238A (zh) * | 2020-06-01 | 2020-10-27 | 北京百度网讯科技有限公司 | 图像的翻译方法和装置、图像翻译模型的训练方法和装置 |
CN111860485A (zh) * | 2020-07-24 | 2020-10-30 | 腾讯科技(深圳)有限公司 | 图像识别模型的训练方法、图像的识别方法、装置、设备 |
CN112287779A (zh) * | 2020-10-19 | 2021-01-29 | 华南农业大学 | 一种低光照度图像自然光照度补强方法及应用 |
CN112488243A (zh) * | 2020-12-18 | 2021-03-12 | 北京享云智汇科技有限公司 | 一种图像翻译方法 |
-
2021
- 2021-05-07 CN CN202110496339.5A patent/CN113112400B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111091493A (zh) * | 2019-12-24 | 2020-05-01 | 北京达佳互联信息技术有限公司 | 图像翻译模型训练方法、图像翻译方法及装置和电子设备 |
CN111222560A (zh) * | 2019-12-30 | 2020-06-02 | 深圳大学 | 一种图像处理模型生成方法、智能终端及存储介质 |
CN111833238A (zh) * | 2020-06-01 | 2020-10-27 | 北京百度网讯科技有限公司 | 图像的翻译方法和装置、图像翻译模型的训练方法和装置 |
CN111860485A (zh) * | 2020-07-24 | 2020-10-30 | 腾讯科技(深圳)有限公司 | 图像识别模型的训练方法、图像的识别方法、装置、设备 |
CN112287779A (zh) * | 2020-10-19 | 2021-01-29 | 华南农业大学 | 一种低光照度图像自然光照度补强方法及应用 |
CN112488243A (zh) * | 2020-12-18 | 2021-03-12 | 北京享云智汇科技有限公司 | 一种图像翻译方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113112400A (zh) | 2021-07-13 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11307864B2 (en) | Data processing apparatus and method | |
US11307865B2 (en) | Data processing apparatus and method | |
CN107609641B (zh) | 稀疏神经网络架构及其实现方法 | |
KR102258414B1 (ko) | 처리 장치 및 처리 방법 | |
KR102142889B1 (ko) | 스파스 연결용 인공 신경망 계산 장치와 방법 | |
KR102486030B1 (ko) | 완전연결층 신경망 정방향 연산 실행용 장치와 방법 | |
US10169084B2 (en) | Deep learning via dynamic root solvers | |
EP3637272A1 (en) | Data sharing system and data sharing method therefor | |
CN107066239A (zh) | 一种实现卷积神经网络前向计算的硬件结构 | |
CN107340993B (zh) | 运算装置和方法 | |
US20210295168A1 (en) | Gradient compression for distributed training | |
CN107239824A (zh) | 用于实现稀疏卷积神经网络加速器的装置和方法 | |
CN107886167A (zh) | 神经网络运算装置及方法 | |
CN107341547A (zh) | 一种用于执行卷积神经网络训练的装置和方法 | |
CN112840356A (zh) | 运算加速器、处理方法及相关设备 | |
CN111353591B (zh) | 一种计算装置及相关产品 | |
CN117094374A (zh) | 电子电路及内存映射器 | |
CN113160041B (zh) | 一种模型训练方法及模型训练装置 | |
CN116468114A (zh) | 一种联邦学习方法及相关装置 | |
CN113112400B (zh) | 一种模型训练方法及模型训练装置 | |
CN116187430A (zh) | 一种联邦学习方法及相关装置 | |
CN109359542A (zh) | 基于神经网络的车辆损伤级别的确定方法及终端设备 | |
US20240144000A1 (en) | Fairness-based neural network model training using real and generated data | |
CN113077383B (zh) | 一种模型训练方法及模型训练装置 | |
WO2022127603A1 (zh) | 一种模型处理方法及相关装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |