CN112884640A - 模型训练方法及相关装置、可读存储介质 - Google Patents
模型训练方法及相关装置、可读存储介质 Download PDFInfo
- Publication number
- CN112884640A CN112884640A CN202110224930.5A CN202110224930A CN112884640A CN 112884640 A CN112884640 A CN 112884640A CN 202110224930 A CN202110224930 A CN 202110224930A CN 112884640 A CN112884640 A CN 112884640A
- Authority
- CN
- China
- Prior art keywords
- model
- generation
- current frame
- original
- generation model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 212
- 238000000034 method Methods 0.000 title claims abstract description 139
- 238000013519 translation Methods 0.000 claims abstract description 113
- 238000004422 calculation algorithm Methods 0.000 claims abstract description 50
- 230000008569 process Effects 0.000 claims description 84
- 230000006870 function Effects 0.000 claims description 65
- 238000004590 computer program Methods 0.000 claims description 29
- 238000013528 artificial neural network Methods 0.000 description 14
- 238000013527 convolutional neural network Methods 0.000 description 13
- 239000011159 matrix material Substances 0.000 description 11
- 238000010586 diagram Methods 0.000 description 9
- 238000004364 calculation method Methods 0.000 description 8
- 239000013598 vector Substances 0.000 description 7
- 230000004913 activation Effects 0.000 description 6
- 238000011176 pooling Methods 0.000 description 5
- 230000009467 reduction Effects 0.000 description 5
- 238000012545 processing Methods 0.000 description 4
- 238000006243 chemical reaction Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 210000002569 neuron Anatomy 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000013256 Gubra-Amylin NASH model Methods 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 230000001419 dependent effect Effects 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 1
- 238000009795 derivation Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 230000000638 stimulation Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T3/00—Geometric image transformations in the plane of the image
- G06T3/04—Context-preserving transformations, e.g. by using an importance map
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Computational Linguistics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Molecular Biology (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例公开了一种模型训练方法及模型训练装置,用于提升图像翻译模型的推理速度。本发明实施例方法包括:采用第一数据和第二数据对原始图像翻译模型中的原始生成模型进行训练,以得到当前帧的第一生成图片,第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,第二数据包括当前帧的原始图片和前两帧的原始图片;对原始生成模型执行fine‑tune微调操作,得到一代生成模型,微调操作包括根据预设的损失函数,计算当前帧的原始图片与当前帧的第一生成图片的第一损失,根据第一损失及反向传播算法,对原始生成模型中卷积层的权重进行梯度更新,一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,一代生成模型的图像生成质量不大于预设的FID值。
Description
技术领域
本发明涉及图像翻译技术领域,尤其涉及模型训练方法及相关装置、可读存储介质。
背景技术
所谓图像翻译,指从一副图像到另一副图像的转换。可以类比机器翻译,将一种语言转换为另一种语言。
现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid。pix2pix提出了一个统一的框架解决了各类图像翻译问题,pix2pixHD则在pix2pix的基础上,较好的解决了高分辨率图像转换(翻译)的问题,vid2vid则在pix2pixHD的基础上,较好的解决了高分辨率的视频转换问题。
但目前的vid2vid模型,如Nvidia的vid2vid中的头部姿态翻译模型,在实际训练过程中,因为其采用的GAN模型数据计算量大,如目前的头部姿态翻译模型需要输入第一部分数据和第二部分数据,其中,第一部分数据包括当前帧和前两帧的轮廓线,以及当前帧和前两帧的distanceMap数据,进一步,每一帧的轮廓线为1维数据,则当前帧和前两帧的轮廓线共3维数据,而每一帧的distanceMap包括4维数据,则当前帧和前两帧的distanceMap共包括12维数据,且上述15(12+3)维数据全部参与模型的训练;第二部分数据包括当前帧和前两帧的原始图片,其中,当前帧为生成图像的Label,前两帧作为训练的输入,且每个图片的大小要求为X×X×3,故第二部分中参与训练的数据为6(2×3)维,故头部姿态翻译模型在训练过程中的训练数据总计为15+6,即21维。
这样,头部姿态翻译模型在实际应用中因为采用训练的数据量大,会出现翻译速度慢、翻译生成的视频帧不连贯、不稳定以及翻译实时性较差的问题。
发明内容
本发明实施例提供了一种模型训练方法及模型训练装置,用于提升图像模型翻译模型的推理速度。
本申请实施例第一方面提供了一种模型训练方法,包括:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
优选的,所述方法还包括:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
优选的,所述方法还包括:
采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述方法还包括:
在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
优选的,所述方法还包括:
采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述方法还包括:
在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
优选的,所述方法还包括:
采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述图像翻译模型包括头部姿态翻译模型、身体姿态翻译模型和街景图像翻译模型中的任一种。
本申请实施例还提供了一种模型训练方法,包括:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
优选的,所述方法还包括:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
本申请实施例第三方面提供了一种模型训练装置,包括:
第一训练单元,用于采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
第二训练单元,用于对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
优选的,所述模型训练装置还包括:
第三训练单元,用于在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
优选的,所述模型训练装置还包括:
第四训练单元,用于采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
第五训练单元,用于对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述模型训练单元还包括:
第六训练单元,用于在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
优选的,所述模型训练单元还包括:
第七训练单元,用于采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
第八训练单元,用于对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述模型训练单元还包括:
第九训练单元,用于在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
优选的,所述模型训练单元还包括:
第十训练单元,用于采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
第十一训练单元,用于对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述图像翻译模型包括头部姿态翻译模型、身体姿态翻译模型和街景图像翻译模型中的任一种。
本申请实施例第四方面还提供了一种模型训练装置,包括:
第一训练单元,用于采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
第二训练单元,用于对所述原始生成模型执行fine-tune微调操作,直至生成一代生成模型,所述微调操作包括根据预设的损失函数计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
优选的,该模型训练装置还包括:
第三训练单元,用于在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
本申请实施例第五方面提供了一种计算机装置,包括处理器,该处理器在执行存储于存储器上的计算机程序时,用于实现本申请实施例第一方面或第二方面提供的模型训练方法。
本申请实施例第六方面提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时,用于实现本申请实施例第一方面或第二方面提供的模型训练方法。
从以上技术方案可以看出,本发明实施例具有以下优点:
本申请实施例中,采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧和前两帧的轮廓线数据,所述第二数据包括当前帧和前两帧的原始图片;对原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述原始生成模型为所述原始图像翻译模型中的GAN网络中的生成模型,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
因为本申请实施例中在对原始生成模型进行训练时,所采用的数据相较于现有技术而言,减少了当前帧和前两帧的distanceMap数据,使得一代生成模型对当前帧的推理生成速度加快,从而保证了图像翻译模型中视频帧的连贯性、稳定性及实时性。
附图说明
图1为本申请实施例中模型训练方法的一个实施例示意图;
图2为本申请实施例中神经网络结构的示意图;
图3为本申请实施例中模型训练方法的另一个实施例示意图;
图4为本申请实施例中模型训练方法的另一个实施例示意图;
图5为本申请实施例中模型训练方法的另一个实施例示意图;
图6为本申请实施例中模型训练方法的另一个实施例示意图;
图7为本申请实施例中模型训练方法的另一个实施例示意图;
图8为本申请实施例中模型训练装置的一个实施例示意图;
图9为本申请实施例中模型训练装置的另一个实施例示意图。
具体实施方式
本发明实施例提供了一种模型训练方法及装置,用于提升模型推理的速度,以保证图像翻译模型中图像帧生成的连续性、稳定性及实时性。
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的实施例能够以除了在这里图示或描述的内容以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
基于现有技术中,图像翻译模型在训练过程中,要采用两部分共21维的数据,从而导致图像翻译模型在训练过程中产生的图像推理速度慢的问题,本申请提出了一种模型训练方法及模型训练装置,用于提升图像翻译模型的图像推理速度。
具体的,现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid,但目前的vid2vid模型,因为其采用的GAN模型数据计算量大,如目前的头部姿态翻译模型需要输入第一部分数据和第二部分数据,其中,第一部分数据包括当前帧和前两帧的轮廓线,以及当前帧和前两帧的distanceMap数据,进一步,每一帧的轮廓线为1维数据,则当前帧和前两帧的轮廓线共3维数据,而每一帧的distanceMap包括4维数据,则当前帧和前两帧的distanceMap共包括12维数据,且上述15(12+3)维数据全部参与模型的训练;第二部分数据包括当前帧和前两帧的原始图片,其中,当前帧为生成图像的Label,前两帧作为训练的输入,且每个图片的大小要求为X×X×3,故第二部分中参与训练的数据为6(2×3)维,故头部姿态翻译模型在训练过程中的训练数据总计为15+6,即21维。
本申请中的模型训练方法,通过减少训练数据的数量,用于提升图像帧的生成速度,也即提升图像帧的推理速度,为方便理解,下面对本申请中的模型训练方法做详细描述,请参阅图1,本申请中模型训练方法的一个实施例,包括:
101、采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
具体的,本申请中的图像翻译模型包括vid2vid模型中的头部姿态翻译模型、身体姿态翻译模式和街景图像翻译模型,其中,关于上述三种模型的具体内容在Nvidia的官方网站上都有详细描述,此处不再赘述。
现有的图像翻译模型采用的都是GAN网络模型,其中,GAN网络应用到深度学习神经网络上来说,就是通过生成模型G(Generator)和判别模型D(Discriminator)不断博弈,进而使G学习到数据的分布,如果用到图片生成上,则训练完成后,G可以从一段随机数中生成逼真的图像。
G、D的主要功能是:
G是一个生成式的网络,它接收一个随机的噪声z(随机数),通过这个噪声生成图像;
D是一个判别网络,判别一张图片是不是“真实的”;它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
训练过程中,生成模型G的目标就是尽量生成真实的图片去欺骗判别模型D,而D的目标就是尽量辨别出G生成的是假图像还是真实的图像。这样,G和D构成了一个动态的博弈过程,最终的平衡点即纳什均衡点,若G和D之间达到纳什均衡点,则结束对G的训练。
本申请中通过改变GAN网络中的生成模型的训练数据,以达到提升生成模型推理速度的目的。
具体的,本申请中采用采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片,其中,当前帧的原始图片作为当前帧生成图片的Label,以用于计算当前帧真实图片与当前帧生成图片之间的损失,而前两帧的原始图片则作为模型的训练输入,对模型进行训练。
由此可知,本申请中在对原始生成模型进行训练时,采用的数据包括:当前帧的轮廓线数据和前两帧的轮廓线数据,以及前两帧的原始图片,其中,每一帧的轮廓线为1维数据,则当前帧和前两帧的轮廓线共3维数据,每一帧的原始图片的大小为X*X*3,其中,X为图片的像素大小,而3为图片通道的个数,其中3通道代表图片为RGB图像,则前两帧的原始图片的数据维数为6(2*3)维,故本申请在对原始生成模型进行训练时,一起包括9(3+6)维的数据,相比于现有技术中的21维数据而言,大大减少了训练数据的输入量,从而使得图像翻译模型在训练过程中的推理速度加快。
具体的,根据实验证实,本申请通过减少训练数据的输入维数,可以将原始生成模型的推理速度从1帧/1s,提升到25帧/s,使得图像翻译模型的推理速度提升了25倍。
102、对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
在步骤101中,将第一数据和第二数据输入到原始生成模型后,对原始生成模型执行fine-tune微调操作,直至生成一代生成模型,其中,该一代生成模型为一代图像翻译模型中的生成模型,且一代生成模型的图像生成质量不大于预设的FID值。
具体的,迁移学习作为一种机器学习思想,应用到深度学习就是微调(Fine-tune)。微调能够快速训练好一个模型,用相对较小的数据量,达到预期的训练结果。
进一步,本申请中的微调操作包括根据预设的损失函数,计算当前帧的原始图片与当前帧的第一生成图片之间的第一损失,然后根据第一损失和返向传播算法,对原始生成模型中卷积层的权重进行梯度更新。
下面具体说明:在目前的图像翻译模型中,采用的损失函数包括5个,第一个图像分布LOSS,用于保障生成图像的真实度;第二个是视频LOSS,用于保障生成视频的连贯性;第三个是光流LOSS,用于保障估算光流的正确性;第四个是Feature matching LOSS,第五个是Content LOSS,其中,第四个和第五个损失函数,是将当前帧的生成图像和当前帧的原始图片分别放到判别模型和VGG 16提取特征,利用特征图计算Element-wise Loss,从而保证内容一致,提升训练稳固性。
其中,具体的每个损失函数在现有技术中也都有详细描述,此处不再赘述。
为方便理解梯度更新的过程,先对GAN网络中的生成模型做简单描述:
图像翻译模型中的生成模型采用的是神经网络算法,而多层感知器(Multi-LayerPerceptron,MLP)也叫人工神经网络(Artificial Neural Network,ANN),一般包括输入层、输出层,及设于输入层和输出层之间的多个隐层。最简单的MLP需要有一层隐层,即输入层、隐层和输出层才能称为一个简单的神经网络。
接下来以图2中的神经网络为例,对数据的传导过程进行描述:
1、神经网络的前向输出
其中,第0层(输入层),我们将x1、x2、x3向量化为X;
0层和1层(隐层)之间,存在权重w1、w2、w3,将权重向量化为W[1],其中W[1]表示第一层的权重;
0层和1层(隐层)之间,还存在偏置b1、b2、b3,将其向量化为b[1],其中b[1]表示第一层的权重;
对于第1层,计算公式为:
Z[1]=W[1]X+b[1];
A[1]=sigmoid(Z[1]);
其中,Z为输入值的线型组合,A为Z通过激活函数sigmoid的值,对于第一层的输入值X,输出值为A,也是下一层的输入值,在sigmoid激活函数中,其取值在[0,1]之间,可以将其理解为一个阀,就像人的神经元一样,当一个神经元受到刺激,并不是立刻感觉到,而是这个刺激超声了阀值,才会让神经元向上级传播。
1层和2层(输出层)之间,与0层和1层之间类似,其计算公式如下:
Z[2]=W[2]X+b[2]
A[2]=sigmoid(Z[2])
yhat=A[2];
其中yhat即为本次神经网络的输出值。
2、损失函数
在神经网络训练的过程中,一般通过损失函数来衡量这个神经网络是否训练到位。
一般情况下,我们选择如下函数作为作为损失函数:
3、反向传播算法
在上述神经网络模型中,可以通过计算损失函数来得到神经网络的训练效果,同时还可以通过反向传播算法,来更新参数,使得神经网络模型更能得到我们想要的预测值。其中,梯度下降算法即为一种优化权重W和偏置b的方法。
具体的,梯度下降算法是对损失函数求偏导数,然后用偏导数来更新w1、w2和b。
z=w1x1+w2x2+b;
然后,分别将对α和z求导:
再对w1、w2和b求导:
接下来用梯度下降算法更新权重参数w和偏置参数b:
其中,w1:=w1-βdw1
w2:=w2-βdw2
b:=b-βdb;
其中,β表示学习率,即学习的步长,在实际训练过程中,如果学习率过大,则会在最优解附近来回震荡,无法达到最优解,如果学习率过小,则可能需要很多次的迭代,才能达到最优解,所以在实际训练过程中,学习率也是一个很重要的选择参数。
但在图像识别技术领域,卷积神经网络(CNN)是一种深度前馈人工神经网络,主要包括卷积层、池化层和全连接层。
其中,卷积神经网络和普通的神经网络前向传播过程类似,CNN中的卷积层也是先由输入和权重做线性运算,然后将结果送进一个激活函数,得到最终的输出,不同点在于卷积神经网络中,权重是卷积核(filter),输入是多维度的像素矩阵,二者进行卷积运算。
而CNN中的池化层则是在卷积层之后,用来对输入进行降采样,降低维度并保持显著的参数,其特点是输出一个固定大小的矩阵,并且降低输出结果的维度,一般池化层中没有需要学习的参数。
全连接层(fully connected layers,FC)在整个卷积神经网络中起到“分类器”的作用。如果说卷积层、池化层和激活函数层等操作是将原始数据映射到隐层特征空间的话,全连接层则起到将学到的“分布式特征表示”映射到样本标记空间的作用。在CNN中,全连接层出现在后几层,用于对前面设计的特征做加权和。在实际使用中,全连接层可由卷积操作实现:对前层是卷积层的全连接层可以转化为卷积核为h*w的全局卷积,h和w分别为前层卷积结果的高和宽。
如果说在CNN中,卷积取的是局部特征的话,全连接就是把以前的局部特征重新通过权值矩阵组成完整的图,一般全连接层中也没有需要学习的参数。
而本申请中的fine-tune微调操作即是计算当前帧的原始图片与当前帧的第一生成图片之间的第一损失,然后利用第一损失和反向传播算法,对卷积神经网络中卷积层的权重,也即卷积核逐步更新的过程,其中,根据损失函数计算第一损失的过程,及具体对卷积核进行更新的过程可以参阅上述损失函数和反向传播算法中的计算过程,此处不再赘述。
进一步,在对原始生成模型执行训练,得到一代生成模型的过程中,为了评价一代生成模型的质量,可以通过FID值对一代生成模型所生成图像的质量进行评估。
具体的,FID(FréchetInceptionDistance)是用来计算真实图像与生成图像的特征向量间距离的一种度量,用这个距离来衡量真实图像和生成图像的相似程度,如果FID值越小,则相似程度越高。最好情况即是FID=0,表示两个图像相同。
本申请实施例中设置一代生成模型的图像生成质量不大于预设的FID值,其中,具体的FID值可以根据实际需求进行设置,此处不做限制。
下面具体对FID的计算过程进行描述:
假设图像的真实分布Pr和生成分布Pg建模为多维高斯分布,参数分别为(μr,Σr)和(μg,Σg),其中μ和Σ分别为均值向量和协方差矩阵。
FID的计算公式为:
其中,Tr表示矩阵的迹(矩阵对角元之和)。
在实际计算FID的过程中,一般假设特征向量维数为n,那么均值向量u的维数为n,协方差矩阵∑的维数为n*n。那么在计算过程中,首先分别选取真实图像和生成图像各N张,计算得到的特征向量为N*n维,之后分别计算这N个样本对应的均值向量u和协方差矩阵∑,即可得到真实分布Pr和生成分布Pg对应的参数。
需要说明的是,本申请实施例中的身体姿态翻译模型,所采用的第一数据,即当前帧和前两帧的轮廓线数据,为身体的骨骼点的连线所组成轮廓线,而非人体外部的轮廓线。
因为本申请实施例中在对原始生成模型进行训练时,相较于现有技术而言,减少了输入数据量,使得对当前帧的推理生成速度变快,从而保证了图像翻译模型中视频帧的实时性、连贯性以及稳定性。
具体的,在头部姿态翻译模型和街景图像翻译模型中,一代生成模型的训练过程中,减少了当前帧和前两帧的distanceMap数据,而每一帧的distancemap数据包括4维数据,则当前帧和前两帧的distancemap数据共包括12维数据,使得一代生成模型的训练数据从21维降为9维,从而提升了当前帧图像的推理生成速度。
基于图1所述的实施例,下面接着对模型训练方法进行描述,在实际训练过程中,为了使得一代生成模型在对卷积层的权重做梯度更新时,能够得到权重的最优解,以提升一代生成模型的质量,本申请实施例还可以执行以下步骤,具体请参阅图3:
301、在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
根据图1实施例中步骤102的描述可知,在对权重做梯度更新的过程中,β表示学习率,即学习的步长,在实际训练过程中,如果学习率过大,则会在最优解附近来回震荡,无法达到最优解。
故本申请在对原始生成模型中的卷积层的权重做梯度更新的过程中,减少了对原始生成模型的学习率,也即减小了β的取值,优选的,本申请实施例中将β值由原来的0.005,减小为0.0005。
本申请实施例中,为了得到更优化的一代生成模型,在对原始生成模型的训练过程中,减少了对原始生成模型的学习率,也即减少了β的取值,使得一代生成模型所生成图片,更接近于真实图片。
基于图3实施例所生成的一代生成模型,为了加快一代生成模型的推理速度,还可以对一代生成模型继续进行训练,得到二代生成模型,下面接着对二代生成模型的训练过程,进行描述,请参阅图4,本申请中模型训练方法的另一个实施例,包括:
401、采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
在实际训练过程中,为了加快一代生成模型的推理速度,本申请实施例还可以采用第三数据和第四数据对一代生成模型进行训练,以得到当前帧的第二生成图片,其中,第三数据包括当前帧和前一帧的轮廓线数据,第四数据包括当前帧和前一帧的原始图片。
因为一代生成模型的训练过程相较于原始生成模型的训练过程而言,在得到当前帧的生成图片时,由依赖于前两帧的信息,变为依赖前一帧的信息,也即将输入数据从9维降低为5维,故一代生成模型训练过程中所采用的数据量,相较于原始生成模型训练过程中所采用的数据量减少,提升了模型的运算速度。
402、对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值;
步骤401中,在采用第三数据和第四数据训练一代生成模型的过程中,对一代生成模型执行fine-tune微调操作,直至得到二代生成模型,其中,fine-tune微调操作具体包括:利用当前帧的原始图片和上述第二生成图片计算第二损失,并根据第二损失和反向传播算法,对一代生成模型中卷积层的权重进行梯度更新。
具体的,对一代生成模型执行fine-tune操作的过程,可以参阅图1实施例中步骤102所述,此处不再赘述。
进一步,为了评估二代生成模型中生成图像的质量,还可以设置FID值,使得二代生成模型的图像质量不大于预设的FID值,具体的FID值的描述也与步骤102中描述的类似,此处也不再赘述。
403、在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
与图3所述的实施例类似,为了得到更优化的二代生成模型,在对一代生成模型的训练过程中,减少了对一代生成模型的学习率,也即减少了β的取值,使得二代生成模型所生成图片,更接近于真实图片。
本申请实施例中,二代生成模型相较于一代生成模型而言,在生成当前帧的图片时,由依赖前两帧图片的信息,变为依赖前一帧的信息,也即一代生成模型在生成当前帧的图片时,需要9维数据,而二代生成模型在生成当前帧的图片时,只需要5维数据,从而减少了运算的数据量,提升了对当前帧图像的推理生成速度。
具体的,一代生成模型在生成当前帧的图片时,是根据前两帧图片的原始图片+前两帧图片的轮廓线数据+当前帧的轮廓线数据,推理得到当前帧的生成图片,而二代生成图片在生成当前帧的图片时,是根据前一帧图片的原始图片+前一帧图片的轮廓线数据+当前帧的轮廓线数据,推理得到当前帧的生成图片,所以二代生成模型的推理速度,相较于一代生成模型的推理速度更快。
基于图4所述的实施例,为了更进一步提升二代生成模型的推理速度,本申请还可以执行以下步骤,具体请参阅图5,本申请实施例中模型训练方法的另一个实施例,包括:
501、采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
在实际训练过程中,为了加快二代生成模型的推理速度,本申请实施例还可以采用第五数据和第六数据对二代生成模型进行训练,以得到当前帧的第三生成图片,其中,第五数据包括当前帧和第一帧的轮廓线数据,第四数据包括当前帧和第一帧的原始图片。
因为一代生成模型在训练过程中,所采用的数据为当前帧的轮廓线数据和前一帧的轮廓线数据,以及当前帧和前一帧的原始图片,而二代生成模型在训练过程中,所采用的数据为当前帧的轮廓线数据和第一帧的轮廓线数据,以及当前帧和第一帧的原始图片。
也就是说,一代生成模型在训练过程中,当前帧的生成图片需要依赖前一帧图像信息,而二代生成模型在训练过程中,当前帧的生成图片只需要依赖第一帧图片信息,在图像推理过程中,二代生成模型在训练过程中,只要第一帧的图像信息固定,即可生成当前帧的图像,而一代生成模型则只有在当前帧的前一帧图像固定后,才可以生成当前当前帧的图像,明显增加了前一帧图像的推理时间。
故二代生成模型训练过程中生成当前帧图像的时间,相较于一代生成模型生成当前帧图像的时间而言,需要的时间更少,也即对当前帧图像的推理速度更快。
502、对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值;
步骤501中,在采用第五数据和第六数据训练二代生成模型的过程中,对二代生成模型执行fine-tune微调操作,直至得到三代生成模型,其中,fine-tune微调操作具体包括:利用当前帧的原始图片和上述第三生成图片计算第三损失,并根据第三损失和反向传播算法,对二代生成模型中卷积层的权重进行梯度更新。
具体的,对二代生成模型执行fine-tune操作的过程,也可以参阅图1实施例中步骤102所述,此处不再赘述。
进一步,为了评估三代生成模型中生成图像的质量,也可以设置FID值,使得三代生成模型的图像质量不大于预设的FID值,具体的FID值的描述也与步骤102中描述的类似,此处也不再赘述。
503、在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
与图3所述的实施例类似,为了得到更优化的三代生成模型,在对二代生成模型的训练过程中,减少了对二代生成模型的学习率,也即减少了β的取值,使得三代生成模型所生成图片,更接近于真实图片。
本申请实施例中,三代生成模型相较于一代生成模型而言,在生成当前帧的图片时,由依赖前一帧图片的信息,变为依赖第一帧的信息,也即二代生成模型在生成当前帧的图片时,是根据前一帧图片的原始图片+前一帧图片的轮廓线数据+当前帧的轮廓线数据,而三代生成模型在生成当前帧的图片时,是根据第一帧图片的原始图片+第一帧图片的轮廓线数据+当前帧的轮廓线数据,明显减少了推理前一帧图片的时间,故三代生成模型想相较于二代生成模型而言,推理速度更快。
基于图5所述的实施例,为了更进一步提升三代生成模型的推理速度,本申请还可以执行以下步骤,具体请参阅图6,本申请实施例中模型训练方法的另一个实施例,包括:
601、采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
在实际训练过程中,为了加快三代生成模型的推理速度,本申请实施例还可以采用第五数据和第七数据对三代生成模型进行训练,以得到当前帧的第四生成图片,其中,第五数据包括当前帧和第一帧的轮廓线数据,第七数据包括当前帧的原始图片,及降低像素后的第一帧的原始图片。
因为在图像识别技术领域,卷积神经网络(CNN)是一种深度前馈人工神经网络,主要包括卷积层、池化层和全连接层。
其中,卷积神经网络和普通的神经网络前向传播过程类似,CNN中的卷积层也是先由输入和权重做线性运算,然后将结果送进一个激活函数,得到最终的输出,不同点在于卷积神经网络中,权重是卷积核(filter),输入是多维度的像素矩阵,二者进行卷积运算。
因为卷积层中的输入是多维度的像素矩阵和卷积核,使其做卷积运算,而像素矩阵中像素值的大小,直接影响运算的速度,很明显在对矩阵执行卷积运算过程中,矩阵元素的数值越大,也即像素值越大,则运算越复杂,得到运算结果的时间就越长,故本申请实施例通过降低像素值的方法,来提升运算速度,达到提升图像推理速度的目的。
需要说明的是,因为第七数据中当前帧的原始图片用于计算第四损失,故不能对当前帧原始图片的像素做处理,而只能对第一帧原始图片的像素做了降低像素的处理,使得在保证图像生成质量的前提下,提升图像的推理速度。
602、对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
具体的,对三代生成模型执行fine-tune操作的过程,可以参阅图1实施例中步骤102所述,此处不再赘述。
进一步,为了评估四代生成模型中生成图像的质量,也可以设置FID值,使得四代生成模型的图像质量不大于预设的FID值,具体的FID值的描述也与步骤102中描述的类似,此处也不再赘述。
本申请实施例中,通过降低第一帧原始图片的像素,来减少卷积层中像素数据的运算量,从而提升四代生成模型对当前帧图片的推理速度。
下面对模型训练方法的另一个实施例进行描述,请参阅图7,本申请实施例中模型训练方法的另一个实施例,包括:
701、采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
区别于图1至图6所述的实施例,本申请实施例在训练原始图像翻译模型中的GAN网络中的原始生成模型时,直接采用第一数据和第二数据,其中,第一数据包括当前帧轮廓线数据和第一帧的轮廓线数据,第二数据包括当前帧的原始图片及降低像素后的第一帧的图片。
这样,使得对原始生成模型执行训练时,所生成的当前帧图片,只依赖第一帧的图像信息,即第一帧的轮廓线数据和降低像素后的第一帧的图片,从而降低了训练过程中的输入数据量,提升了模型对当前帧图片的推理速度。
具体的,现有技术对原始模型执行训练时,需要输入21维的数据,其中包括当前帧和前两帧的轮廓线(共3维数据),及当前帧和前两帧的distanceMap数据(共12维数据),以及当前帧和前两帧的原始图片(共6维数据),其中,当前帧的原始图片为生成图像的label,不参加训练,且每个图片的大小为X*X*3。
而本申请实施例在对原始生成模型执行训练时,只需要输入5维的数据,具体包括当前帧和第一帧的轮廓线数据(共3维数据),当前帧的原始图片及降低像素后的第一帧的图片(共2维),其中,当前帧的原始图片为生成图像的label,不参加训练。
且在对原始生成模型执行训练的过程中,原始生成模型在生成当前帧图像时,只需要依赖第一帧的图像信息,即可快速推理出当前帧的图像,而现有技术在训练过程中,在生成当前帧的图像时,则要依赖当前帧的前一帧的图像,故本申请实施例相比于现有技术而言,在生成当前帧图像时,减少了当前帧的前一帧的推理时间,提升了当前帧图像的推理速度。
进一步,为了节省运算量,本申请实施例还采用降低像素后的第一帧的图像,从而减少了CNN网络中卷积层的运算数据量,进一步提升了当前帧图像的推理速度。
702、对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
具体的,对原始生成模型执行fine-tune操作的过程,可以参阅图1实施例中步骤102所述,此处不再赘述。
进一步,为了一代生成模型中生成图像的质量,也可以设置FID值,使得四代生成模型的图像质量不大于预设的FID值,具体的FID值的描述也与步骤102中描述的类似,此处也不再赘述。
703、在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
与图3所述的实施例类似,为了得到更优化的一代生成模型,在对原始生成模型的训练过程中,减少了对原始生成模型的学习率,也即减少了β的取值,使得一代生成模型所生成图片,更接近于真实图片。
本申请实施例直接采用当前帧的轮廓线数据和第一帧的轮廓线数据,以及当前帧的原始图片及降低像素后的第一帧的图片,对原始生成模型执行训练,以得到一代生成模型,故一代生成模型较原始生成模型而言,参与训练的数据量更少,相应的的一代生成模型的推理速度也更快。
上面对本申请实施例中的模型训练方法做了描述,下面接着对本申请实施例中的模型训练装置进行描述,请参阅图8,本申请实施例中模型训练装置的一个实施例,包括:
第一训练单元801,用于采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
第二训练单元802,用于对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
优选的,所述模型训练装置还包括:
第三训练单元803,用于在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
优选的,所述模型训练装置还包括:
第四训练单元804,用于采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
第五训练单元805,用于对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述模型训练单元还包括:
第六训练单元806,用于在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
优选的,所述模型训练单元还包括:
第七训练单元807,用于采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
第八训练单元808,用于对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述模型训练单元还包括:
第九训练单元809,用于在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
优选的,所述模型训练单元还包括:
第十训练单元810,用于采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
第十一训练单元811,用于对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
优选的,所述图像翻译模型包括头部姿态翻译模型、身体姿态翻译模型和街景图像翻译模型中的任一种。
因为本申请实施例通过第一训练单元801和第二训练单元802,在对原始生成模型进行训练时,所采用的数据量减少(从21维减少为9维),故使得原始生成模型对当前帧的推理生成速度变快,从而保证了图像翻译模型中视频帧的实时性、连贯性以及稳定性。
进一步,本申请又实施例通过第四训练单元804,将训练数据从9维降低为5维,进一步减少了运算的数据量,提升了模型的推理速度。
进一步,本申请实施例又通过第七训练单元807,将生成当前帧图像时所依赖的前一帧的图像,变更为依赖第一帧的图像,从而减少了在生成当前帧时,所依赖的前一帧图像的生成时间,提升了模型的推理速度。
进一步,本申请实施例又通过第十训练单元810,将生成当前帧图像所依赖的第一帧的图像信息,做降低像素的处理,从而进一步减少了在生成当前帧的图像时,卷积神经网络中卷积层的运算量,提升了模型的推理速度。
下面接着对模型训练装置的另一个实施例进行描述,请参阅图9,本申请实施例中模型训练装置的另一个实施例,包括:
第一训练单元901,用于采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
第二训练单元902,用于对所述原始生成模型执行fine-tune微调操作,直至生成一代生成模型,所述微调操作包括根据预设的损失函数计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
优选的,该模型训练装置还包括:
第三训练单元903,用于在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
因为本申请实施例通过第一训练单元901,直接采用当前帧的轮廓线数据和第一帧的轮廓线数据,以及当前帧的原始图片及降低像素后的第一帧的图片,对原始生成模型执行训练,以得到一代生成模型,故一代生成模型较原始生成模型而言,参与训练的数据量更少(从21维变为5维),故得到的一代生成模型的推理速度也更快。
上面从模块化功能实体的角度对本发明实施例中的模型训练装置进行了描述,下面从硬件处理的角度对本发明实施例中的计算机装置进行描述:
该计算机装置用于实现模型训练装置的功能,本发明实施例中计算机装置一个实施例包括:
处理器以及存储器;
存储器用于存储计算机程序,处理器用于执行存储器中存储的计算机程序时,可以实现如下步骤:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
本申请实施例还提供了一种计算机装置,该计算机装置用于实现模型训练装置的功能,本发明实施例中计算机装置的另一个实施例包括:
处理器以及存储器;
存储器用于存储计算机程序,处理器用于执行存储器中存储的计算机程序时,可以实现如下步骤:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
在本发明的一些实施例中,处理器,还可以用于实现如下步骤:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
可以理解的是,上述说明的计算机装置中的处理器执行所述计算机程序时,也可以实现上述对应的各装置实施例中各单元的功能,此处不再赘述。示例性的,所述计算机程序可以被分割成一个或多个模块/单元,所述一个或者多个模块/单元被存储在所述存储器中,并由所述处理器执行,以完成本发明。所述一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述所述计算机程序在所述模型训练装置中的执行过程。例如,所述计算机程序可以被分割成上述模型训练装置中的各单元,各单元可以实现如上述相应模型训练装置说明的具体功能。
所述计算机装置可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机装置可包括但不仅限于处理器、存储器。本领域技术人员可以理解,处理器、存储器仅仅是计算机装置的示例,并不构成对计算机装置的限定,可以包括更多或更少的部件,或者组合某些部件,或者不同的部件,例如所述计算机装置还可以包括输入输出设备、网络接入设备、总线等。
所述处理器可以是中央处理单元(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable GateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等,所述处理器是所述计算机装置的控制中心,利用各种接口和线路连接整个计算机装置的各个部分。
所述存储器可用于存储所述计算机程序和/或模块,所述处理器通过运行或执行存储在所述存储器内的计算机程序和/或模块,以及调用存储在存储器内的数据,实现所述计算机装置的各种功能。所述存储器可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序等;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非易失性存储器,例如硬盘、内存、插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(SecureDigital,SD)卡,闪存卡(Flash Card)、至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
本发明还提供了一种计算机可读存储介质,该计算机可读存储介质用于实现模型训练装置的功能,其上存储有计算机程序,计算机程序被处理器执行时,处理器,可以用于执行如下步骤:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
本申请还提供了另一种计算机可读存储介质,该计算机可读存储介质也用于实现模型训练装置的功能,其上存储有计算机程序,计算机程序被处理器执行时,处理器,可以用于执行如下步骤:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
在本发明的一些实施例中,计算机可读存储介质存储的计算机程序被处理器执行时,处理器,可以具体用于执行如下步骤:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
可以理解的是,所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在相应的一个计算机可读取存储介质中。基于这样的理解,本发明实现上述相应的实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
以上所述,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (14)
1.一种模型训练方法,其特征在于,所述方法包括:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
2.根据权利要求1所述的模型训练方法,其特征在于,所述方法还包括:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
3.根据权利要求2所述的方法,其特征在于,所述方法还包括:
采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。
4.根据权利要求3所述的方法,其特征在于,所述方法还包括:
在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。
5.根据权利要求4所述的方法,其特征在于,所述方法还包括:
采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。
6.根据权利要求5所述的方法,其特征在于,所述方法还包括:
在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。
7.根据权利要求6所述的方法,其特征在于,所述方法还包括:
采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损失,根据所述第四损失及所述反向传播算法,对所述三代生成模型中卷积层的权重进行梯度更新,其中,所述四代生成模型为四代图像翻译模型中的GAN网络中的生成模型,且所述四代生成模型的图像生成质量不大于所述预设的FID值。
8.根据权利要求1至7中任一项所述的模型训练方法,其特征在于,所述图像翻译模型包括头部姿态翻译模型、身体姿态翻译模型和街景图像翻译模型中的任一种。
9.一种模型训练方法,其特征在于,所述方法包括:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
10.根据权利要求10所述的方法,其特征在于,所述方法还包括:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。
11.一种模型训练装置,其特征在于,所述装置包括:
第一训练单元,用于采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
第二训练单元,用于对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
12.一种模型训练装置,其特征在于,所述装置包括:
第一训练单元,用于采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的生成图片,所述第一数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第二数据包括当前帧的原始图片及降低像素后的第一帧的图片;
第二训练单元,用于对所述原始生成模型执行fine-tune微调操作,直至生成一代生成模型,所述微调操作包括根据预设的损失函数计算所述当前帧的原始图像与所述当前帧生成图片的损失,根据所述损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。
13.一种计算机装置,包括处理器,其特征在于,所述处理器在执行存储于存储器上的计算机程序时,用于实现如权利要求1至8中任一项,或者权利要求9至10中任一项所述的模型训练方法。
14.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,用于实现如权利要求1至8中任一项,或者权利要求9至10中任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110224930.5A CN112884640B (zh) | 2021-03-01 | 2021-03-01 | 模型训练方法及相关装置、可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110224930.5A CN112884640B (zh) | 2021-03-01 | 2021-03-01 | 模型训练方法及相关装置、可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112884640A true CN112884640A (zh) | 2021-06-01 |
CN112884640B CN112884640B (zh) | 2024-04-09 |
Family
ID=76054989
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110224930.5A Active CN112884640B (zh) | 2021-03-01 | 2021-03-01 | 模型训练方法及相关装置、可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112884640B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113077383A (zh) * | 2021-06-07 | 2021-07-06 | 深圳追一科技有限公司 | 一种模型训练方法及模型训练装置 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110826688A (zh) * | 2019-09-23 | 2020-02-21 | 江苏艾佳家居用品有限公司 | 一种保障gan模型最大最小损失函数平稳收敛的训练方法 |
CN111625608A (zh) * | 2020-04-20 | 2020-09-04 | 中国地质大学(武汉) | 一种基于gan模型根据遥感影像生成电子地图的方法、系统 |
CN112164008A (zh) * | 2020-09-29 | 2021-01-01 | 中国科学院深圳先进技术研究院 | 图像数据增强网络的训练方法及其训练装置、介质和设备 |
US20210042503A1 (en) * | 2018-11-14 | 2021-02-11 | Nvidia Corporation | Generative adversarial neural network assisted video compression and broadcast |
-
2021
- 2021-03-01 CN CN202110224930.5A patent/CN112884640B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210042503A1 (en) * | 2018-11-14 | 2021-02-11 | Nvidia Corporation | Generative adversarial neural network assisted video compression and broadcast |
CN110826688A (zh) * | 2019-09-23 | 2020-02-21 | 江苏艾佳家居用品有限公司 | 一种保障gan模型最大最小损失函数平稳收敛的训练方法 |
CN111625608A (zh) * | 2020-04-20 | 2020-09-04 | 中国地质大学(武汉) | 一种基于gan模型根据遥感影像生成电子地图的方法、系统 |
CN112164008A (zh) * | 2020-09-29 | 2021-01-01 | 中国科学院深圳先进技术研究院 | 图像数据增强网络的训练方法及其训练装置、介质和设备 |
Non-Patent Citations (1)
Title |
---|
黄菲;高飞;朱静洁;戴玲娜;俞俊;: "基于生成对抗网络的异质人脸图像合成:进展与挑战", 南京信息工程大学学报(自然科学版), no. 06, 28 November 2019 (2019-11-28) * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113077383A (zh) * | 2021-06-07 | 2021-07-06 | 深圳追一科技有限公司 | 一种模型训练方法及模型训练装置 |
CN113077383B (zh) * | 2021-06-07 | 2021-11-02 | 深圳追一科技有限公司 | 一种模型训练方法及模型训练装置 |
Also Published As
Publication number | Publication date |
---|---|
CN112884640B (zh) | 2024-04-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109949255B (zh) | 图像重建方法及设备 | |
Iqbal et al. | Generative adversarial network for medical images (MI-GAN) | |
Oyedotun et al. | Deep learning in vision-based static hand gesture recognition | |
CN108510012B (zh) | 一种基于多尺度特征图的目标快速检测方法 | |
Liang et al. | MCFNet: Multi-layer concatenation fusion network for medical images fusion | |
WO2022036777A1 (zh) | 基于卷积神经网络的人体动作姿态智能估计方法及装置 | |
WO2021043168A1 (zh) | 行人再识别网络的训练方法、行人再识别方法和装置 | |
An et al. | Medical image segmentation algorithm based on feedback mechanism CNN | |
CN112668366B (zh) | 图像识别方法、装置、计算机可读存储介质及芯片 | |
CN113256592B (zh) | 图像特征提取模型的训练方法、系统及装置 | |
CN113989890A (zh) | 基于多通道融合和轻量级神经网络的人脸表情识别方法 | |
CN114863225B (zh) | 图像处理模型训练方法、生成方法、装置、设备及介质 | |
CN112560639B (zh) | 人脸关键点数目转换方法、系统、电子设备及存储介质 | |
Shahamatnia et al. | Application of particle swarm optimization and snake model hybrid on medical imaging | |
CN113205017A (zh) | 跨年龄人脸识别方法及设备 | |
CN109961397B (zh) | 图像重建方法及设备 | |
Gao et al. | Integrated GANs: Semi-supervised SAR target recognition | |
Ninh et al. | Skin lesion segmentation based on modification of SegNet neural networks | |
CN110570394A (zh) | 医学图像分割方法、装置、设备及存储介质 | |
Qiao et al. | A pseudo-siamese feature fusion generative adversarial network for synthesizing high-quality fetal four-chamber views | |
Angelopoulou et al. | Fast 2d/3d object representation with growing neural gas | |
CN113763535A (zh) | 一种特征潜码提取方法、计算机设备及存储介质 | |
CN112884640A (zh) | 模型训练方法及相关装置、可读存储介质 | |
CN113850796A (zh) | 基于ct数据的肺部疾病识别方法及装置、介质和电子设备 | |
CN113158970A (zh) | 一种基于快慢双流图卷积神经网络的动作识别方法与系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |