CN112052948B - 一种网络模型压缩方法、装置、存储介质和电子设备 - Google Patents
一种网络模型压缩方法、装置、存储介质和电子设备 Download PDFInfo
- Publication number
- CN112052948B CN112052948B CN202010837744.4A CN202010837744A CN112052948B CN 112052948 B CN112052948 B CN 112052948B CN 202010837744 A CN202010837744 A CN 202010837744A CN 112052948 B CN112052948 B CN 112052948B
- Authority
- CN
- China
- Prior art keywords
- network model
- countermeasure network
- loss
- generator
- output result
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 75
- 230000006835 compression Effects 0.000 title claims abstract description 55
- 238000007906 compression Methods 0.000 title claims abstract description 55
- 238000003860 storage Methods 0.000 title claims abstract description 19
- 238000012549 training Methods 0.000 claims abstract description 48
- 238000012545 processing Methods 0.000 claims abstract description 15
- 230000006870 function Effects 0.000 claims description 92
- 238000013140 knowledge distillation Methods 0.000 claims description 46
- 230000003044 adaptive effect Effects 0.000 claims description 15
- 238000013519 translation Methods 0.000 claims description 9
- 238000004590 computer program Methods 0.000 claims description 5
- 238000013473 artificial intelligence Methods 0.000 abstract description 9
- 238000013135 deep learning Methods 0.000 abstract description 3
- 238000013256 Gubra-Amylin NASH model Methods 0.000 description 48
- 238000005516 engineering process Methods 0.000 description 12
- 239000008186 active pharmaceutical agent Substances 0.000 description 8
- 230000000694 effects Effects 0.000 description 7
- 238000010801 machine learning Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 4
- 238000013459 approach Methods 0.000 description 4
- 238000005457 optimization Methods 0.000 description 4
- 241000282326 Felis catus Species 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 230000008485 antagonism Effects 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000013136 deep learning model Methods 0.000 description 2
- 230000005251 gamma ray Effects 0.000 description 2
- 230000010354 integration Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 241001465754 Metazoa Species 0.000 description 1
- 230000001133 acceleration Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000007599 discharging Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Data Exchanges In Wide-Area Networks (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请实施例公开了一种网络模型压缩方法、装置、存储介质和电子设备,所述方法涉及人工智能领域中的深度学习方向,包括:获取训练完成的第一生成式对抗网络模型,初始化第二生成式对抗网络模型,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。该方案可以获取到有效保存第一生成式对抗网络模型能力,并且大大减少网络模型参数量的第二生成式对抗网络模型。
Description
技术领域
本申请涉及计算机技术领域,具体涉及一种网络模型压缩方法、装置、存储介质和电子设备。
背景技术
生成式对抗网络模型是一种深度学习模型,生成式对抗网络模型通过框架中生成模型和判别模型之间的互相博弈学习产生相当好的输出,因此,生成式对抗网络模型在图像/视频翻译、文本/图像/视频生成等任务上都有重要的应用。但是目前生成式对抗网络模型的模型结构都比较复杂,并且包含大量的参数,不利于实际的应用。
发明内容
本申请实施例提供一种网络模型压缩方法、装置、存储介质和电子设备,该方案可以获取到有效保存第一生成式对抗网络模型能力,并且大大减少网络模型参数量的第二生成式对抗网络模型。
本申请实施例提供一种网络模型压缩方法,包括:
获取训练完成的第一生成式对抗网络模型,所述第一生成式对抗网络模型包括第一生成器和第一判别器;
初始化第二生成式对抗网络模型,所述第二生成式对抗网络模型与所述第一生成式对抗网络模型针对相同的模型任务,所述第二生成式对抗网络模型的网络模型参数量小于所述第一生成式对抗网络模型的网络模型参数量,且所述第二生成式对抗网络模型包括第二生成器和第二判别器;
将训练数据分别输入至所述第一生成式对抗网络模型、和所述第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果;
基于所述第一输出结果和所述第二输出结果,生成交叉判别损失,所述交叉判别损失为所述第一生成式对抗网络模型中第一生成器和第一判别器、与所述第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失;
基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
相应的,本申请实施例还提供一种网络模型压缩装置,包括:
获取模块,用于获取训练完成的第一生成式对抗网络模型,所述第一生成式对抗网络模型包括第一生成器和第一判别器;
初始化模块,用于初始化第二生成式对抗网络模型,所述第二生成式对抗网络模型与所述第一生成式对抗网络模型针对相同的模型任务,所述第二生成式对抗网络模型的网络模型参数量小于所述第一生成式对抗网络模型的网络模型参数量,且所述第二生成式对抗网络模型包括第二生成器和第二判别器;
处理模块,用于将训练数据分别输入至所述第一生成式对抗网络模型、和所述第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果;
生成模块,用于基于所述第一输出结果和所述第二输出结果,生成交叉判别损失,所述交叉判别损失为所述第一生成式对抗网络模型中第一生成器和第一判别器、与所述第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失;
迭代模块,用于基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
则此时,所述生成模块,具体可以用于基于所述第一生成结果、以及所述第二判别结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失;基于所述第二生成结果、以及所述第一判别结果,生成以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失;基于所述第二生成结果、以及所述第二判别结果,生成利用所述第二生成器自监督所述第二判别器的第三交叉判别损失;融合所述第一交叉判别损失、所述第二交叉判别损失、以及所述第三交叉判别损失,得到交叉判别损失。
可选的,在一些实施例中,所述迭代模块可以包括生成子模块和第一迭代子模块,如下:
生成子模块,用于基于所述第一输出结果和所述第二输出结果,生成知识蒸馏损失;
第一迭代子模块,用于基于所述交叉判别损失、以及所述知识蒸馏损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
则此时,所述生成子模块,具体可以用于基于所述第一生成结果和所述第二生成结果,生成表征所述第一生成器和所述第二生成器之间差异的生成器损失;基于所述第一判别结果和所述第二判别结果,生成表征所述第一判别器和所述第二判别器之间差异的判别器损失;融合所述生成器损失、以及所述判别器损失,得到知识蒸馏损失。
则此时,所述迭代模块,具体可以用于基于所述交叉判别损失、以及所述知识蒸馏损失,通过梯度下降算法迭代更新所述第二生成式对抗网络模型的网络模型参数;基于自适应参数调试函数,迭代优化目标损失函数中的权重参数,其中,所述目标损失函数为基于所述交叉判别损失、以及所述知识蒸馏损失构建的损失函数;循环执行上述网络模型参数和权重参数的更新步骤直至收敛,得到压缩后的目标生成式对抗网络模型。
则此时,所述迭代模块,具体还可以用于基于不等式约束条件,确定求解所述目标损失函数中权重参数的求解条件;基于所述求解条件,确定优化所述权重参数的自适应参数调试函数。
此外,本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行本申请实施例提供的任一种网络模型压缩方法中的步骤。
此外,本申请实施例还提供一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如本申请实施例提供的任一种网络模型压缩方法中的步骤。
本申请实施例可以获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。该方案可以提升第二生成式对抗网络模型、与已经训练完成的第一生成式对抗网络模型之间的相似程度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的网络模型压缩系统的场景示意图;
图2是本申请实施例提供的网络模型压缩方法的第一流程图;
图3是本申请实施例提供的网络模型压缩方法的第二流程图;
图4是本申请实施例提供的训练过程中的损失函数的示意图;
图5是本申请实施例提供的网络模型压缩装置的结构示意图;
图6是本申请实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本申请实施例提供一种网络模型压缩方法、装置、存储介质和电子设备。具体地,本申请实施例的网络模型压缩方法可以由电子设备执行,其中,该电子设备可以为终端或者服务器等设备,该终端可以为手机、平板电脑、笔记本电脑、智能电视、穿戴式智能设备、个人计算机(PC,Personal Computer)等设备。其中,终端可以包括客户端,该客户端可以是视频客户端或浏览器客户端等,服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云计算服务的云服务器。
例如,参见图1,以该网络模型压缩方法由电子设备执行为例,该电子设备可以获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
本申请实施例提供的网络模型压缩方法涉及人工智能领域中的机器学习方向。本申请实施例可以获取训练完成的第一生成式对抗网络模型,初始化第二生成式对抗网络模型,然后利用交叉判别损失对第二生成式对抗网络模型进行训练,训练完成后得到压缩后的目标生成式对抗网络模型。
其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。其中,人工智能软件技术主要包括计算机视觉技术、机器学习/深度学习等方向。
其中,机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、式教学习等技术。
以下分别进行详细说明。需说明的是,以下实施例的描述顺序不作为对实施例优选顺序的限定。
本申请实施例提供了一种网络模型压缩方法,该方法可以由终端或服务器执行,也可以由终端和服务器共同执行;本申请实施例以网络模型压缩方法由服务器执行为例来进行说明,如图2所示,该网络模型压缩方法的具体流程可以如下:
201、获取训练完成的第一生成式对抗网络模型。
其中,生成式对抗网络模型(GAN,Generative Adversarial Networks)是一种深度学习模型。生成式对抗网络模型通过框架中生成模型(Generative Model)和判别模型(Discriminative Model)之间的互相博弈学习产生相当好的输出,因此,生成式对抗网络模型在图像/视频翻译、文本/图像/视频生成等任务上都有重要的应用。比如,本申请中的第一生成式对抗网络模型可以包括第一生成器和第一判别器,其中,第一生成器即为生成式对抗网络模型中的生成模型,第一判别器即为生成式对抗网络模型中的判别模型。
其中,知识蒸馏是一种基于神经网络的信息提取方式,同时也是一种有效的网络压缩方式,通过集成或者大规模训练的方式生成一个教师网络,然后将该教师网络的输出标签进行软化,从而增加不同类别之间的信息量,使得对于不同模型分类任务的兼容性更强。当面临实际问题的时候,教师网络会指导训练学生网络生成相应模型来解决实际的分类或识别问题,该学生网络可以有效地将教师网络中优秀的分类能力和预测能力继承下来,并且减少了教师网络的冗余性和复杂度,同时又提高了学生网络的性能。
其中,在本申请实施例中,可以将知识蒸馏方法中的教师网络称为第一生成式对抗网络模型,将知识蒸馏方法中的学生网络称为第二生成式对抗网络模型。
其中,生成器是利用给定的隐含信息,来随机产生观测数据的网络模型,比如,可以给定一系列猫的图像,利用生成器生成一张新的猫的图像。
其中,判别器是能够对输入变量进行预测的网络模型。比如,可以给定一张图像,并利用判别器判断这张图中的动物是猫还是狗。
在实际应用中,比如,生成式对抗网络模型在图像/视频翻译、文本/图像/视频生成等任务上有重要的应用。但是目前的生成式对抗网络模型结构都比较复杂并且包含大量的参数,如何将训练好的生成式模型进行合适的压缩从而部署到服务器或者移动设备上是实际生产中一个非常重要优化方向。本申请可以通过知识蒸馏的方法来进行生成式对抗网络模型的压缩,通过知识蒸馏得到的生成式对抗网络模型大大减少了网络模型参数量,并且结构简单,容易部署。
由于本申请是利用知识蒸馏的方法进行网络模型的压缩,因此,需要首先获取到压缩模型时能够作为基准的教师网络,也即已经训练完成的第一生成式对抗网络模型,如图4所示,该第一生成式对抗网络模型包括训练后固定的第一生成器GT、和训练后固定的第一判别器DT,该第一生成式对抗网络模型精度高,并且网络模型参数量多。
202、初始化第二生成式对抗网络模型。
其中,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器。
在实际应用中,比如,可以预先设置一个第二生成式对抗网络模型作为知识蒸馏方法中的学生网络,该第二生成式对抗网络模型的网络模型参数量少于第一生成式对抗网络模型的网络模型参数量,如图4所示,该第二生成式对抗网络模型包括第二生成器GS和第二判别器DS,并且可以将第二生成式对抗网络模型进行网络参数初始化为WGS和WDS。
203、将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果。
在实际应用中,比如,可以将训练数据输入到第一生成式对抗网络模型中,得到第一输出结果,并且将训练数据输入到第二生成式对抗网络模型中,得到第二输出结果。
204、基于第一输出结果和第二输出结果,生成交叉判别损失。
其中,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失。
在实际应用中,可以利用交叉判别损失来提升最终生成的第二生成式对抗网络模型与第一生成式对抗网络模型之间的相似程度,并解决知识蒸馏方法中教师网络和学生网络之间不匹配的问题。其中,该交叉判别损失的逻辑在于对第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间进行交叉监督,比如,该交叉判别损失可以利用第一生成式对抗网络模型的第一判别器DT来监督第二生成式对抗网络模型的第二生成器GS,利用第一生成式对抗网络模型的第一生成器GT来监督第二生成式对抗网络模型的第二判别器DS,同时利用第二生成式对抗网络模型的第二判别器DS来自监督第二生成式对抗网络模型的第二生成器GS。
在一实施例中,具体地,步骤“基于所述第一输出结果和所述第二输出结果,生成交叉判别损失”,可以包括:
基于所述第一生成结果、以及所述第二判别结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失;
基于所述第二生成结果、以及所述第一判别结果,生成以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失;
基于所述第二生成结果、以及所述第二判别结果,生成利用所述第二生成器自监督所述第二判别器的第三交叉判别损失;
融合所述第一交叉判别损失、所述第二交叉判别损失、以及所述第三交叉判别损失,得到交叉判别损失。
其中,由于第一生成式对抗网络模型包括第一生成器和第一判别器,第二生成式对抗网络模型包括第二生成器和第二判别器,因此,基于第一生成式对抗网络模型生成的第一输出结果可以包括第一生成结果和第一判别结果,基于第二生成式对抗网络模型生成的第二输出结果可以包括第二生成结果和第二判别结果。
在一实施例中,比如,可以提出一种新的损失函数:交叉判别损失函数(crossgenerator-discriminator loss)来解决学生网络无法有效逼近教师网络的问题,该交叉判别损失函数的公式可以如下:
LCGD(GS,DS)=γ1LGAN(GS,DT)+γ2LGAN(GT,DS)+γ3LGAN(GS,DS)
其中,该交叉判别损失函数的公式由三项组成,LGAN(GS,DT)表示利用第一生成式对抗网络模型的第一判别器来评估第二生成式对抗网络模型的第二生成器的好坏,LGAN(GT,DS)表示利用第一生成式对抗网络模型的第一生成器来评估第二生成式对抗网络模型的第二判别器的好坏,LGAN(GS,DS)表示利用第二生成式对抗网络模型的第二判别器对第二生成器进行自我监督。γ1、γ2和γ3表示损失函数中的权重系数。本申请实施例中的交叉判别损失函数的形式还可以推广为其他类型的函数形式。
在一实施例中,还可以引入多种不同类型的损失函数来监督第二生成式对抗网络模型的第二生成器和第二判别器。具体地,步骤“基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型”,可以包括:
基于所述第一输出结果和所述第二输出结果,生成知识蒸馏损失;
基于所述交叉判别损失、以及所述知识蒸馏损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
其中,知识蒸馏损失可以表征第一生成式对抗网络模型的第一生成器与第二生成式对抗网络模型的第二生成器之间的差异、以及第一生成式对抗网络模型的第一判别器与第二生成式对抗网络模型的第二判别器之间的差异。
在实际应用中,比如,可以利用知识蒸馏损失、以及交叉判别损失,对第二生成式对抗网络模型的网络模型参数进行更新,以便得到与第一生成式对抗网络模型更为相似的第二生成式对抗网络模型。
在一实施例中,具体地,步骤“基于所述第一输出结果和所述第二输出结果,生成知识蒸馏损失”,可以包括:
基于所述第一生成结果和所述第二生成结果,生成表征所述第一生成器和所述第二生成器之间差异的生成器损失;
基于所述第一判别结果和所述第二判别结果,生成表征所述第一判别器和所述第二判别器之间差异的判别器损失;
融合所述生成器损失、以及所述判别器损失,得到知识蒸馏损失。
在实际应用中,比如,还可以引入度量第一生成式对抗网络模型的第一生成器与第二生成式对抗网络模型的第二生成器之间差异的生成器损失、以及度量第一生成式对抗网络模型的第一判别器与第二生成式对抗网络模型的第二判别器之间差异的判别器损失,其中,生成器损失对应的损失函数公式可以如下:
其中,判别器损失对应的损失函数公式可以如下:
如图4所示,可以将知识蒸馏损失与交叉判别损失进行综合,得到综合的损失函数,其公式可以如下:
其中,λi和μi表示损失函数中的权重系数。
205、基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
在实际应用中,比如,在根据交叉判别损失确定了损失函数之后,可以利用自适应随机梯度算法(Adam)来更新第二生成式对抗网络模型的第二生成器和第二判别器的网络模型参数。其中,在采用反向传播算法中基于Adam的梯度下降法优化网络参数的过程中,可以将学习率参数设定为η。其中,本申请实施例中用于训练第二生成式对抗网络模型的自适应随机梯度算法(Adam)还可以替换为其他训练神经网络的算法,如随机梯度下降(SGD)、AMSGrad等等。
在一实施例中,可以通过自适应调试损失函数权重系数的方法降低手动调参带来的额外成本,同时提高网络训练效率。具体地,步骤“基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型”,可以包括:
基于所述交叉判别损失、以及所述知识蒸馏损失,通过梯度下降算法迭代更新所述第二生成式对抗网络模型的网络模型参数;
基于自适应参数调试函数,迭代优化目标损失函数中的权重参数,其中,所述目标损失函数为基于所述交叉判别损失、以及所述知识蒸馏损失构建的损失函数;
循环执行上述网络模型参数和权重参数的更新步骤直至收敛,得到压缩后的目标生成式对抗网络模型。
在实际应用中,基于交叉判别损失、以及知识蒸馏损失,可以构建第二生成式对抗网络模型的损失函数,该损失函数中包括不同的权重系数λ、μ、γ。由于在第二生成式对抗网络模型中引入了不同类型的损失函数,使得网络模型训练难度大大提高,并且由于第二生成式对抗网络模型的minmax结构特点,缺少有效的度量准则来指导调试不同损失函数权重系数λ、μ、γ。因此可以通过自适应调试损失函数权重系数的方法降低手动调参带来的额外成本,同时提高网络训练效率。
比如,本申请可以基于第二生成式对抗网络模型的最优性条件,来设置第二生成式对抗网络模型不同损失函数的权重系数。在第二生成式对抗网络模型的训练过程中,可以交替更新第二生成式对抗网络模型的网络模型参数、以及不同损失函数的权重系数,以便实现第二生成式对抗网络模型中不同损失函数的权重系数自动化调整为最优权重,同时大幅降低网络模型的调参成本。
在一实施例中,具体地,该网络模型压缩方法还可以包括:
基于不等式约束条件,确定求解所述目标损失函数中权重参数的求解条件;
基于所述求解条件,确定优化所述权重参数的自适应参数调试函数。
在实际应用中,比如,可以从第二生成式对抗网络模型的KKT条件(也即不等式约束条件)出发,求解第二生成式对抗网络模型,等价于找到第二生成式对抗网络模型的KKT系统的零点,也即求解条件,其中,求解条件公式可以如下:
可以看到,第二生成式对抗网络模型的第二生成器、第二判别器、以及损失函数的权重系数λ、μ、γ,均影响了KKT系统接近0的快慢。可以通过极小化KKT系统的误差,来估计损失函数的权重系数,如可以将其定义为以下的凸优化问题,自适应参数调试函数的公式可以如下:
结合上述的损失函数的权重系数、以及第二生成式对抗网络模型的训练模型,可以将第二生成式对抗网络模型结构优化为以下双层优化的问题:
在一实施例中,比如,可以通过反向传播算法计算出度量函数的关于GS、DS的梯度,然后根据这些梯度的大小来更新第二生成式对抗网络模型中各类型损失函数的权重,从而得到第二生成式对抗网络模型的损失函数。再利用自适应随机梯度算法(Adam)来更新第二生成式对抗网络模型的第二生成器和第二判别器的网络参数。将以上流程交替进行,直至第二生成式对抗网络模型的网络参数收敛。
其中,可以将第二生成式对抗网络模型的训练算法总结如下:
设置网络结构和参数:给定数据X,给定第一生成式对抗网络模型的网络结构。给定小参数量的第二生成式对抗网络模型的第二生成器GS和第二判别器DS,并将网络参数初始化为WGS和WDS。
输出训练好的第二生成式对抗网络模型的第二生成器GS和第二判别器DS的模型参数
其中,通过上述流程可知,本申请在交替的训练第二生成式对抗网络模型、以及更新损失函数的权重系数。同时,网络模型的权重系数根据当前模型的网络参数WGS和WDS、以及当前的训练数据自动更新。利用本申请的模型压缩方法,可以大幅减缓网络模型的训练难度。其中,估计第二生成式对抗网络模型损失函数权重的Frank-wolfe算法也可以替换为其他的一阶优化算法,如投影梯度算法,增广拉格朗日函数算法等。
在一实施例中,可以利用该网络模型压缩方法得到一个参数量较少的第二生成式对抗网络模型,并且其效果与第一生成式对抗网络模型的效果相当。比如,在人像的图像/视频翻译的任务上,可以首先训练一个高精度的第一生成式对抗网络模型,并通过本申请的网络模型压缩方法得到一个效果好参数量少的第二生成式对抗网络模型。同时,本申请提出了交叉判别损失,以解决知识蒸馏技术中教师网络和学生网络之间不匹配的问题,另外,本申请还采用了自动化调整损失函数的权重系数的方法,降低了学生网络的训练难度。因此,利用本申请的网络模型压缩方法得到的第二生成式对抗网络模型,能够快速有效的部署到服务器和移动设备上,并显著的加速模型推断的效果与实时的推流速度。
比如,可以将该第二生成式对抗网络模型部署在显卡上,可以达到100FPS的推断速度,同时也可以将第二生成式对抗网络模型部署在移动端设备上,并达到实时的推断速度,这种方法可以使得在移动端和服务器端均能实时生成虚拟主播、虚拟解说、虚拟教师等等,大幅降低了人力成本,并为虚拟人物的落地和推广提供了强有力的支持。
在一实施例中,本申请的网络模型压缩方法,除了可以应用在文本/图像/视频翻译任务中,还可以用到其他的回归任务的模型压缩问题上,比如实时超分辨率、目标检测、语义分割等等。
由上可知,本申请实施例可以获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。该方案可以通过网络模型的压缩,得到参数量较少且效果与第一生成式对抗网络模型相当的第二生成式对抗网络模型。同时,本方案提出的交叉判别损失可以解决知识蒸馏技术中教师网络和学生网络之间不匹配的问题,另外,本方案还采用了自动化调整损失函数的权重系数的方法,降低了学生网络的训练难度。因此,利用本申请的网络模型压缩方法得到的第二生成式对抗网络模型,能够快速有效的部署到服务器和移动设备上,并显著地提升网络模型的推理速度,大幅降低了人力成本,并为虚拟人物的落地和推广提供了强有力的支持。
根据前面实施例所描述的方法,以下将以该网络模型压缩装置具体集成在电子设备中举例作进一步详细说明。
参考图3,本申请实施例的网络模型压缩方法的具体流程可以如下:
301、给定训练完成教师GAN模型的生成器GT和判别器DT,并将其固定。
在实际应用中,比如,可以在给定数据上训练得到一个精度高,参数量多的教师GAN模型,并将预先训练好的教师GAN模型的生成器记作GT,将教师GAN模型的判别器记作DT。
302、给定小参数量的学生GAN模型的生成器GS和判别器DS,并将网络参数初始化为WGS和WDS。
在实际应用中,比如,可以预先设置一个参数量较少的学生GAN模型,学生GAN模型的生成器记作GS,学生GAN模型的判别器记作DS。在学生GAN模型的训练过程中,可以将训练数据输入到教师GAN网络中得到GT,将训练数据输入到学生GAN网络中得到GS。
303、确定交叉判别损失函数和知识蒸馏损失函数。
在实际应用中,比如,可以引入交叉判别损失,利用教师GAN模型的生成器GT来监督学生GAN模型的判别器DS,教师GAN模型的判别器DT来监督学生GAN模型的生成器GS,以及学生GAN模型的生成器GS来自监督学生GAN模型的判别器DS。同时,还可以引入知识蒸馏损失,利用教师GAN模型的生成器GT来监督学生GAN模型的生成器GS,教师GAN模型的判别器DT来监督学生GAN模型的判别器DS。
其中,交叉判别损失函数公式可以如下:
LCGD(GS,DS)=γ1LGAN(GS,DT)+γ2LGAN(GT,DS)+γ3LGAN(GS,DS)
其中,该交叉判别损失函数的公式由三项组成,LGAN(GS,DT)表示利用教师GAN模型的判别器来评估学生GAN模型的生成器的好坏,LGAN(GT,DS)表示利用教师GAN模型的生成器来评估学生GAN模型的判别器的好坏,LGAN(GS,DS)表示利用学生GAN模型的判别器对生成器进行自我监督。γ1、γ2和γ3表示损失函数中的权重系数。
其中,知识蒸馏损失函可以包括度量教师GAN模型的生成器与学生GAN模型的生成器之间差异的生成器损失、以及度量教师GAN模型的判别器与学生GAN模型的判别器之间差异的判别器损失。
其中,生成器损失对应的损失函数公式可以如下:
其中,判别器损失对应的损失函数公式可以如下:
可以将知识蒸馏损失与交叉判别损失进行综合,得到综合的损失函数,其公式可以如下:
其中,λi和μi表示损失函数中的权重系数。
304、通过自适应随机梯度更新学生GAN模型的参数WGS和WDS。
305、自适应调试损失函数的权重系数。
在实际应用中,比如,可以通过反向传播算法计算出这些度量函数的关于GS和DS的梯度。随后,我们根据这些梯度大小来更新学生GAN模型中各类型损失函数的权重系数,从而得到学生GAN模型的加权损失函数。
其中,可以从学生GAN模型的KKT条件出发求解学生GAN模型,等价于找到学生GAN模型的KKT系统的零点:
可以看到,学生GAN模型的生成器、判别器、以及损失函数的权重系数λ、μ、γ,均影响了KKT系统接近0的快慢。可以通过极小化KKT系统的误差,来估计损失函数的权重系数,如可以将其定义为以下的凸优化问题,自适应参数调试函数的公式可以如下:
结合上述的损失函数的权重系数、以及第二生成式对抗网络模型的训练模型,可以将第二生成式对抗网络模型结构优化为以下双层优化的问题:
306、交替的更新学生GAN模型的参数、以及更新损失函数的权重系数。
307、当学生GAN模型的网络参数收敛时,得到目标学生GAN模型。
在实际应用中,比如,可以将学生GAN模型的训练算法总结如下:
设置网络结构和参数:给定数据X,给定教师GAN模型的网络结构。给定小参数量的学生GAN模型的生成器GS和判别器DS,并将网络参数初始化为WGS和WDS。
输出训练好的学生GAN模型的生成器GS和判别器DS的模型参数
由上可知,本申请实施例可以通过电子设备给定训练完成教师GAN模型的生成器GT和判别器DT,并将其固定,给定小参数量的学生GAN模型的生成器GS和判别器DS,并将网络参数初始化为WGS和WDS,确定交叉判别损失函数和知识蒸馏损失函数,通过自适应随机梯度更新学生GAN模型的参数WGS和WDS,自适应调试损失函数的权重系数,交替的更新学生GAN模型的参数、以及更新损失函数的权重系数,当学生GAN模型的网络参数收敛时,得到目标学生GAN模型。该方案可以通过网络模型的压缩,得到参数量较少且效果与第一生成式对抗网络模型相当的第二生成式对抗网络模型。同时,本方案提出的交叉判别损失可以解决知识蒸馏技术中教师网络和学生网络之间不匹配的问题,另外,本方案还采用了自动化调整损失函数的权重系数的方法,降低了学生网络的训练难度。因此,利用本申请的网络模型压缩方法得到的第二生成式对抗网络模型,能够快速有效的部署到服务器和移动设备上,并显著地提升网络模型的推理速度,大幅降低了人力成本,并为虚拟人物的落地和推广提供了强有力的支持。
为了更好地实施以上方法,相应的,本申请实施例还提供一种网络模型压缩装置,该网络模型压缩装置可以集成在电子设备中,参考图5,该网络模型压缩装置包括获取模块51、初始化模块52、处理模块53、生成模块54和迭代模块55,如下:
获取模块51,用于获取训练完成的第一生成式对抗网络模型,所述第一生成式对抗网络模型包括第一生成器和第一判别器;
初始化模块52,用于初始化第二生成式对抗网络模型,所述第二生成式对抗网络模型与所述第一生成式对抗网络模型针对相同的模型任务,所述第二生成式对抗网络模型的网络模型参数量小于所述第一生成式对抗网络模型的网络模型参数量,且所述第二生成式对抗网络模型包括第二生成器和第二判别器;
处理模块53,用于将训练数据分别输入至所述第一生成式对抗网络模型、和所述第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果;
生成模块54,用于基于所述第一输出结果和所述第二输出结果,生成交叉判别损失,所述交叉判别损失为所述第一生成式对抗网络模型中第一生成器和第一判别器、与所述第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失;
迭代模块55,用于基于所述交叉判别损失、以及所述知识蒸馏损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
在一实施例中,所述生成模块54可以具体用于:
基于所述第一生成结果、以及所述第二判别结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失;
基于所述第二生成结果、以及所述第一判别结果,生成以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失;
基于所述第二生成结果、以及所述第二判别结果,生成利用所述第二生成器自监督所述第二判别器的第三交叉判别损失;
融合所述第一交叉判别损失、所述第二交叉判别损失、以及所述第三交叉判别损失,得到交叉判别损失。
在一实施例中,所述迭代模块55可以包括生成子模块和第一迭代子模块,如下:
生成子模块,用于基于所述第一输出结果和所述第二输出结果,生成知识蒸馏损失;
第一迭代子模块,用于基于所述交叉判别损失、以及所述知识蒸馏损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
在一实施例中,所述生成子模块可以具体用于:
基于所述第一生成结果和所述第二生成结果,生成表征所述第一生成器和所述第二生成器之间差异的生成器损失;
基于所述第一判别结果和所述第二判别结果,生成表征所述第一判别器和所述第二判别器之间差异的判别器损失;
融合所述生成器损失、以及所述判别器损失,得到知识蒸馏损失。
在一实施例中,所述迭代模块55可以具体用于:
基于所述交叉判别损失、以及所述知识蒸馏损失,通过梯度下降算法迭代更新所述第二生成式对抗网络模型的网络模型参数;
基于自适应参数调试函数,迭代优化目标损失函数中的权重参数,其中,所述目标损失函数为基于所述交叉判别损失、以及所述知识蒸馏损失构建的损失函数;
循环执行上述网络模型参数和权重参数的更新步骤直至收敛,得到压缩后的目标生成式对抗网络模型。
在一实施例中,所述迭代模块55还可以具体用于:
基于不等式约束条件,确定求解所述目标损失函数中权重参数的求解条件;
基于所述求解条件,确定优化所述权重参数的自适应参数调试函数。
具体实施时,以上各个单元可以作为独立的实体来实现,也可以进行任意组合,作为同一或若干个实体来实现,以上各个单元的具体实施可参见前面的方法实施例,在此不再赘述。
由上可知,本申请实施例可以通过获取模块51获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,通过初始化模块52初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,通过处理模块53将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,通过生成模块54基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,通过迭代模块55基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。该方案可以通过网络模型的压缩,得到参数量较少且效果与第一生成式对抗网络模型相当的第二生成式对抗网络模型。同时,本方案提出的交叉判别损失可以解决知识蒸馏技术中教师网络和学生网络之间不匹配的问题,另外,本方案还采用了自动化调整损失函数的权重系数的方法,降低了学生网络的训练难度。因此,利用本申请的网络模型压缩方法得到的第二生成式对抗网络模型,能够快速有效的部署到服务器和移动设备上,并显著地提升网络模型的推理速度,大幅降低了人力成本,并为虚拟人物的落地和推广提供了强有力的支持。
本申请实施例还提供一种电子设备,该电子设备可以集成本申请实施例所提供的任一种网络模型压缩装置。
例如,如图6所示,其示出了本申请实施例所涉及的电子设备的结构示意图,具体来讲:
该电子设备可以包括一个或者一个以上处理核心的处理器61、一个或一个以上计算机可读存储介质的存储器62、电源63和输入单元64等部件。本领域技术人员可以理解,图6中示出的电子设备结构并不构成对电子设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
其中:
处理器61是该电子设备的控制中心,利用各种接口和线路连接整个电子设备的各个部分,通过运行或执行存储在存储器62内的软件程序和/或模块,以及调用存储在存储器62内的数据,执行电子设备的各种功能和处理数据,从而对电子设备进行整体检测。可选的,处理器61可包括一个或多个处理核心;优选的,处理器61可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、玩家界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器61中。
存储器62可用于存储软件程序以及模块,处理器61通过运行存储在存储器62的软件程序以及模块,从而执行各种功能应用以及数据处理。存储器62可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据电子设备的使用所创建的数据等。此外,存储器62可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储器62还可以包括存储器控制器,以提供处理器61对存储器62的访问。
电子设备还包括给各个部件供电的电源63,优选的,电源63可以通过电源管理系统与处理器61逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。电源63还可以包括一个或一个以上的直流或交流电源、再充电系统、电源故障检测电路、电源转换器或者逆变器、电源状态指示器等任意组件。
该电子设备还可包括输入单元64,该输入单元64可用于接收输入的数字或字符信息,以及产生与玩家设置以及功能控制有关的键盘、鼠标、操作杆、光学或者轨迹球信号输入。
尽管未示出,电子设备还可以包括显示单元等,在此不再赘述。具体在本实施例中,电子设备中的处理器61会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文本加载到存储器62中,并由处理器61来运行存储在存储器62中的应用程序,从而实现各种功能,如下:
获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
由上可知,本申请实施例可以获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。该方案可以通过网络模型的压缩,得到参数量较少且效果与第一生成式对抗网络模型相当的第二生成式对抗网络模型。同时,本方案提出的交叉判别损失可以解决知识蒸馏技术中教师网络和学生网络之间不匹配的问题,另外,本方案还采用了自动化调整损失函数的权重系数的方法,降低了学生网络的训练难度。因此,利用本申请的网络模型压缩方法得到的第二生成式对抗网络模型,能够快速有效的部署到服务器和移动设备上,并显著地提升网络模型的推理速度,大幅降低了人力成本,并为虚拟人物的落地和推广提供了强有力的支持。
本领域普通技术人员可以理解,上述实施例的各种方法中的全部或部分步骤可以通过指令来完成,或通过指令控制相关的硬件来完成,该指令可以存储于一计算机可读存储介质中,并由处理器进行加载和执行。
为此,本申请实施例提供一种电子设备,其中存储有多条指令,该指令能够被处理器进行加载,以执行本申请实施例所提供的任一种网络模型压缩方法中的步骤。例如,该指令可以执行如下步骤:
获取训练完成的第一生成式对抗网络模型,第一生成式对抗网络模型包括第一生成器和第一判别器,初始化第二生成式对抗网络模型,第二生成式对抗网络模型与第一生成式对抗网络模型针对相同的模型任务,第二生成式对抗网络模型的网络模型参数量小于第一生成式对抗网络模型的网络模型参数量,且第二生成式对抗网络模型包括第二生成器和第二判别器,将训练数据分别输入至第一生成式对抗网络模型、和第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果,基于第一输出结果和第二输出结果,生成交叉判别损失,交叉判别损失为第一生成式对抗网络模型中第一生成器和第一判别器、与第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,基于交叉判别损失,迭代更新第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
根据本申请的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述网络模型压缩方面的各种可选实现方式中提供的方法。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
其中,该存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、磁盘或光盘等。
由于该存储介质中所存储的指令,可以执行本申请实施例所提供的任一种网络模型压缩方法中的步骤,因此,可以实现本申请实施例所提供的任一种网络模型压缩方法所能实现的有益效果,详见前面的实施例,在此不再赘述。
以上对本申请实施例所提供的一种网络模型压缩方法、装置、存储介质和电子设备进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
Claims (9)
1.一种网络模型压缩方法,其特征在于,包括:
获取训练完成的第一生成式对抗网络模型,所述第一生成式对抗网络模型包括第一生成器和第一判别器;
初始化第二生成式对抗网络模型,所述第二生成式对抗网络模型与所述第一生成式对抗网络模型针对相同的模型任务,所述第二生成式对抗网络模型的网络模型参数量小于所述第一生成式对抗网络模型的网络模型参数量,且所述第二生成式对抗网络模型包括第二生成器和第二判别器;
将训练数据分别输入至所述第一生成式对抗网络模型、和所述第二生成式对抗网络模型中进行处理,得到第一输出结果和第二输出结果;
基于所述第一输出结果和所述第二输出结果,生成交叉判别损失,所述交叉判别损失为所述第一生成式对抗网络模型中第一生成器和第一判别器、与所述第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,包括:基于所述第一输出结果和所述第二输出结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失、以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失、以及利用所述第二生成器自监督所述第二判别器的第三交叉判别损失;融合所述第一交叉判别损失、所述第二交叉判别损失、以及所述第三交叉判别损失,得到交叉判别损失;
基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型,以便将所述压缩后的目标生成式对抗网络模型用于图像翻译任务、或视频翻译任务、或文本生成任务、或图像生成任务、或视频生成任务。
2.根据权利要求1所述的网络模型压缩方法,其特征在于,所述第一输出结果包括第一生成结果和第一判别结果,所述第二输出结果包括第二生成结果和第二判别结果;
基于所述第一输出结果和所述第二输出结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失、以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失、以及利用所述第二生成器自监督所述第二判别器的第三交叉判别损失,包括:
基于所述第一生成结果、以及所述第二判别结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失;
基于所述第二生成结果、以及所述第一判别结果,生成以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失;
基于所述第二生成结果、以及所述第二判别结果,生成利用所述第二生成器自监督所述第二判别器的第三交叉判别损失。
3.根据权利要求1所述的网络模型压缩方法,其特征在于,基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型,包括:
基于所述第一输出结果和所述第二输出结果,生成知识蒸馏损失;
基于所述交叉判别损失、以及所述知识蒸馏损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型。
4.根据权利要求3所述的网络模型压缩方法,其特征在于,所述第一输出结果包括第一生成结果和第一判别结果,所述第二输出结果包括第二生成结果和第二判别结果;
基于所述第一输出结果和所述第二输出结果,生成知识蒸馏损失,包括:
基于所述第一生成结果和所述第二生成结果,生成表征所述第一生成器和所述第二生成器之间差异的生成器损失;
基于所述第一判别结果和所述第二判别结果,生成表征所述第一判别器和所述第二判别器之间差异的判别器损失;
融合所述生成器损失、以及所述判别器损失,得到知识蒸馏损失。
5.根据权利要求3所述的网络模型压缩方法,其特征在于,基于所述交叉判别损失、以及所述知识蒸馏损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型,包括:
基于所述交叉判别损失、以及所述知识蒸馏损失,通过梯度下降算法迭代更新所述第二生成式对抗网络模型的网络模型参数;
基于自适应参数调试函数,迭代优化目标损失函数中的权重参数,其中,所述目标损失函数为基于所述交叉判别损失、以及所述知识蒸馏损失构建的损失函数;
循环执行上述网络模型参数和权重参数的更新步骤直至收敛,得到压缩后的目标生成式对抗网络模型。
6.根据权利要求5所述的网络模型压缩方法,其特征在于,所述方法还包括:
基于不等式约束条件,确定求解所述目标损失函数中权重参数的求解条件;
基于所述求解条件,确定优化所述权重参数的自适应参数调试函数。
7.一种网络模型压缩装置,其特征在于,包括:
获取模块,用于获取训练完成的第一生成式对抗网络模型,所述第一生成式对抗网络模型包括第一生成器和第一判别器;
初始化模块,用于初始化第二生成式对抗网络模型,所述第二生成式对抗网络模型与所述第一生成式对抗网络模型针对相同的模型任务,所述第二生成式对抗网络模型的网络模型参数量小于所述第一生成式对抗网络模型的网络模型参数量,且所述第二生成式对抗网络模型包括第二生成器和第二判别器;
处理模块,用于将训练数据分别输入至所述第一生成式对抗网络模型、和所述第二生成式对抗网络模型进行处理,得到第一输出结果和第二输出结果;
生成模块,用于基于所述第一输出结果和所述第二输出结果,生成交叉判别损失,所述交叉判别损失为所述第一生成式对抗网络模型中第一生成器和第一判别器、与所述第二生成式对抗网络模型中第二生成器和第二判别器之间交叉监督所得的损失,包括:基于所述第一输出结果和所述第二输出结果,生成以固定的所述第一生成器监督所述第二判别器的第一交叉判别损失、以固定的所述第一判别器监督所述第二生成器的第二交叉判别损失、以及利用所述第二生成器自监督所述第二判别器的第三交叉判别损失;融合所述第一交叉判别损失、所述第二交叉判别损失、以及所述第三交叉判别损失,得到交叉判别损失;
迭代模块,用于基于所述交叉判别损失,迭代更新所述第二生成式对抗网络模型的网络模型参数,得到压缩后的目标生成式对抗网络模型,以便将所述压缩后的目标生成式对抗网络模型用于图像翻译任务、或视频翻译任务、或文本生成任务、或图像生成任务、或视频生成任务。
8.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,当所述计算机程序在计算机上运行时,使得所述计算机执行如权利要求1-6任一项所述的网络模型压缩方法。
9.一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,其中,所述处理器执行所述程序时实现如权利要求1至6任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010837744.4A CN112052948B (zh) | 2020-08-19 | 2020-08-19 | 一种网络模型压缩方法、装置、存储介质和电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010837744.4A CN112052948B (zh) | 2020-08-19 | 2020-08-19 | 一种网络模型压缩方法、装置、存储介质和电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112052948A CN112052948A (zh) | 2020-12-08 |
CN112052948B true CN112052948B (zh) | 2023-11-14 |
Family
ID=73600623
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010837744.4A Active CN112052948B (zh) | 2020-08-19 | 2020-08-19 | 一种网络模型压缩方法、装置、存储介质和电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112052948B (zh) |
Families Citing this family (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115146522A (zh) * | 2021-03-31 | 2022-10-04 | 西门子股份公司 | 模型训练方法、诊断方法、装置、电子设备和可读介质 |
CN113177612B (zh) * | 2021-05-24 | 2022-09-13 | 同济大学 | 一种基于cnn少样本的农业病虫害图像识别方法 |
CN113408265B (zh) * | 2021-06-22 | 2023-01-17 | 平安科技(深圳)有限公司 | 基于人机交互的语义解析方法、装置、设备及存储介质 |
CN113449851A (zh) * | 2021-07-15 | 2021-09-28 | 北京字跳网络技术有限公司 | 数据处理方法及设备 |
CN113570493B (zh) * | 2021-07-26 | 2024-07-16 | 京东科技信息技术有限公司 | 一种图像生成方法及装置 |
CN113780534B (zh) * | 2021-09-24 | 2023-08-22 | 北京字跳网络技术有限公司 | 网络模型的压缩方法、图像生成方法、装置、设备及介质 |
CN117808067A (zh) * | 2022-09-23 | 2024-04-02 | 华为技术有限公司 | 神经网络剪枝方法及装置 |
CN117953108B (zh) * | 2024-03-20 | 2024-07-05 | 腾讯科技(深圳)有限公司 | 图像生成方法、装置、电子设备和存储介质 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110084281A (zh) * | 2019-03-31 | 2019-08-02 | 华为技术有限公司 | 图像生成方法、神经网络的压缩方法及相关装置、设备 |
CN110390950A (zh) * | 2019-08-17 | 2019-10-29 | 杭州派尼澳电子科技有限公司 | 一种基于生成对抗网络的端到端语音增强方法 |
WO2019222401A2 (en) * | 2018-05-17 | 2019-11-21 | Magic Leap, Inc. | Gradient adversarial training of neural networks |
CN110796619A (zh) * | 2019-10-28 | 2020-02-14 | 腾讯科技(深圳)有限公司 | 一种图像处理模型训练方法、装置、电子设备及存储介质 |
CN110880036A (zh) * | 2019-11-20 | 2020-03-13 | 腾讯科技(深圳)有限公司 | 神经网络压缩方法、装置、计算机设备及存储介质 |
-
2020
- 2020-08-19 CN CN202010837744.4A patent/CN112052948B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019222401A2 (en) * | 2018-05-17 | 2019-11-21 | Magic Leap, Inc. | Gradient adversarial training of neural networks |
CN110084281A (zh) * | 2019-03-31 | 2019-08-02 | 华为技术有限公司 | 图像生成方法、神经网络的压缩方法及相关装置、设备 |
CN110390950A (zh) * | 2019-08-17 | 2019-10-29 | 杭州派尼澳电子科技有限公司 | 一种基于生成对抗网络的端到端语音增强方法 |
CN110796619A (zh) * | 2019-10-28 | 2020-02-14 | 腾讯科技(深圳)有限公司 | 一种图像处理模型训练方法、装置、电子设备及存储介质 |
CN110880036A (zh) * | 2019-11-20 | 2020-03-13 | 腾讯科技(深圳)有限公司 | 神经网络压缩方法、装置、计算机设备及存储介质 |
Non-Patent Citations (2)
Title |
---|
Angeline aguinaldo,et al.Compressing GANs using knowledge distillation.《cs.LG》.2019,1-10页. * |
基于生成对抗网络与知识蒸馏的人脸修复与表情识别;姜慧明;《中国优秀硕士学位论文全文数据库 信息科技辑》;第2022年卷(第08期);I138-499 * |
Also Published As
Publication number | Publication date |
---|---|
CN112052948A (zh) | 2020-12-08 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112052948B (zh) | 一种网络模型压缩方法、装置、存储介质和电子设备 | |
US11790238B2 (en) | Multi-task neural networks with task-specific paths | |
CN112329948B (zh) | 一种多智能体策略预测方法及装置 | |
CN111259738B (zh) | 人脸识别模型构建方法、人脸识别方法及相关装置 | |
WO2020159890A1 (en) | Method for few-shot unsupervised image-to-image translation | |
CN111898703B (zh) | 多标签视频分类方法、模型训练方法、装置及介质 | |
CN113361680A (zh) | 一种神经网络架构搜索方法、装置、设备及介质 | |
CN112287656B (zh) | 文本比对方法、装置、设备和存储介质 | |
CN113344184B (zh) | 用户画像预测方法、装置、终端和计算机可读存储介质 | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
CN116595356B (zh) | 时序信号预测方法、装置、电子设备及存储介质 | |
CN115168720A (zh) | 内容交互预测方法以及相关设备 | |
CN111282272A (zh) | 信息处理方法、计算机可读介质及电子设备 | |
CN114611692A (zh) | 模型训练方法、电子设备以及存储介质 | |
CN114861671A (zh) | 模型训练方法、装置、计算机设备及存储介质 | |
CN117633184A (zh) | 一种模型构建和智能回复方法、设备及介质 | |
CN110866609B (zh) | 解释信息获取方法、装置、服务器和存储介质 | |
CN113392867A (zh) | 一种图像识别方法、装置、计算机设备及存储介质 | |
CN117312979A (zh) | 对象分类方法、分类模型训练方法及电子设备 | |
CN116541507A (zh) | 一种基于动态语义图神经网络的视觉问答方法及系统 | |
CN115168722A (zh) | 内容交互预测方法以及相关设备 | |
WO2022127603A1 (zh) | 一种模型处理方法及相关装置 | |
CN117010480A (zh) | 模型训练方法、装置、设备、存储介质及程序产品 | |
CN112052386B (zh) | 信息推荐方法、装置和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |