CN114065834A - 一种模型训练方法、终端设备及计算机存储介质 - Google Patents

一种模型训练方法、终端设备及计算机存储介质 Download PDF

Info

Publication number
CN114065834A
CN114065834A CN202111164452.XA CN202111164452A CN114065834A CN 114065834 A CN114065834 A CN 114065834A CN 202111164452 A CN202111164452 A CN 202111164452A CN 114065834 A CN114065834 A CN 114065834A
Authority
CN
China
Prior art keywords
model
teacher
student
output value
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202111164452.XA
Other languages
English (en)
Other versions
CN114065834B (zh
Inventor
周翊民
黄仲浩
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Institute of Advanced Technology of CAS
Original Assignee
Shenzhen Institute of Advanced Technology of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Institute of Advanced Technology of CAS filed Critical Shenzhen Institute of Advanced Technology of CAS
Priority to CN202111164452.XA priority Critical patent/CN114065834B/zh
Priority claimed from CN202111164452.XA external-priority patent/CN114065834B/zh
Publication of CN114065834A publication Critical patent/CN114065834A/zh
Application granted granted Critical
Publication of CN114065834B publication Critical patent/CN114065834B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本申请提供了一种模型训练方法、终端设备以及计算机存储介质。该模型训练方法包括:获取待训练图像;采用待训练图像训练多个阶段的老师模型;利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对待训练图像的识别结果;按照识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用学生模型的预测输出值和老师模型的预测输出值计算学生模型的损失值,以及按照损失值调整所述学生模型的模型参数。通过上述方式,本申请的模型训练方法采用多个阶段的老师模型分别对不同阶段的学生模型进行训练,从而使得学生模型在假扮老师模型的过程中不断接近甚至超过老师模型的性能。

Description

一种模型训练方法、终端设备及计算机存储介质
技术领域
本申请涉及人工智能应用技术领域,特别是涉及一种模型训练方法、终端设备以及计算机存储介质。
背景技术
在过去的几年里,深度学习已经成为人工智能成功的基础,包括计算机视觉中的各种应用,强化学习和自然语言处理。借助最近的许多技术,包括残差连接和批处理归一化,可以很容易地在强大的GPU(graphics processing unit,图形处理器)或CPU(centralprocessing unit,中央处理器)集群上训练具有数千层的深度模型。例如,对于一个流行的图像识别基准数百万数据集,它只需要不到10分钟的时间来训练一个ResNet(残差网络)模型;训练一个强大的BERT(Bidirectional Encoder Representation from Transformers)模型进行语言理解只需要不到一个半小时。大规模的深度模型已经取得了巨大的成功,但是巨大的计算复杂度和巨大的存储需求使得在实时应用中部署它们成为一个巨大的挑战,特别是在资源有限的设备上,如视频监控和非实时驾驶汽车。而且,对于一个具有数千层的深度模型(也称为笨重的模型),超过85%的权重即使它们消失,也对模型性能的影响不那么显著,这也说明了笨重的模型中包含了大量的冗余信息。
现有的方法大多直接使用训练好的老师模型对学生模型进行蒸馏,但学生-老师模型之间的性能差距过大,造成了知识在传递过程中的一定损失。其次,虽然大多数现有的方法实现了高精度的目标网络,但模型的巨大计算复杂度和复杂的模型训练过程使其在实际应用中面临巨大的挑战。其次,现有的大多数方法忽略了网络结构、通道数量和师生初始化条件差异对模型性能的影响,直接利用老师模型的相关输出来指导学生模型的训练,导致训练效率低下,训练效果不佳。
发明内容
本申请提供了一种模型训练方法、终端设备以及计算机存储介质。
本申请提供了一种模型训练方法,所述模型训练方法包括:
获取待训练图像;
采用所述待训练图像训练多个阶段的老师模型;
利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对所述待训练图像的识别结果;
按照所述识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,以及按照所述损失值调整所述学生模型的模型参数。
其中,所述采用所述待训练图像训练多个阶段的老师模型,包括:
在一个阶段的训练中,初始化所述老师模型的模型参数,将所述待训练图像输入所述老师模型;
按照第一预设迭代次数对所述老师模型进行训练后,冻结所述老师模型在这一阶段的模型参数;
在下一个阶段的训练中,重新初始化所述老师模型的模型参数,将所述待训练图像输入所述老师模型;
按照第二预设迭代次数对所述老师模型进行训练后,冻结所述老师模型在这一阶段的模型参数。
其中,所述利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,包括:
利用所述学生模型的预测输出值和所述老师模型的预测输出值,计算所述学生模型的均方误差损失值;
利用所述学生模型的预测输出值和预设的目标输出值,计算所述学生模型的交叉熵损失值;
将所述均方误差损失值和所述交叉熵损失值组成得到所述学生模型的损失值。
其中,所述模型训练方法,还包括:
将同一阶段的老师模型和学生模型按照相同划分方式将网络模型划分为相同数量的卷积块;
利用所述老师模型相应卷积块的预测输出值对所述学生模型的每一个卷积块进行训练,获取所述学生模型中目标卷积块的预测输出值,以及所述老师模型中相应卷积块的预测输出值;
利用所述目标卷积块的预测输出值和所述相应卷积块的预测输出值计算所述目标卷积块的损失值,以及按照所述损失值调整所述目标卷积块的模型参数。
其中,所述模型训练方法,还包括:
将所述待训练图像分别输入所述老师模型和所述学生模型,获取所述老师模型中参考卷积块的预测输出值和所述学生模型中目标卷积块的预测输出值,其中,所述参考卷积块和所述目标卷积块在网络模型中的位置相同;
利用所述目标卷积块的预测输出值和所述参考卷积块的预测输出值计算所述目标卷积块的均方误差损失值,以及按照所述均方误差损失值调整所述目标卷积块的模型参数。
其中,所述模型训练方法,还包括:
将所述学生模型中的目标卷积块作为生成器;
将所述老师模型中的参考卷积块后面的卷积块作为鉴别器;
利用所述生成器和所述鉴别器组成生成对抗网络;
将所述待训练图像输入所述生成对抗网络,获取所述生成对抗网络的损失值;
利用所述生成对抗网络的损失值,调整所述目标卷积块的模型参数。
其中,所述将所述待训练图像输入所述生成对抗网络,获取所述生成对抗网络的损失值,包括:
将所述待训练图像分别输入所述生成器和所述鉴别器;
将所述生成器基于所述待训练图像生成的鉴别图像,输入所述鉴别器;
获取所述鉴别器训练所述待训练图像得到的第一损失值,和训练所述鉴别图像得到的第二损失值,并组成所述生成对抗网络的损失值。
本申请还提供了一种终端设备,所述终端设备包括:
获取模块,用于获取待训练图像;
训练模块,用于采用所述待训练图像训练多个阶段的老师模型;
所述训练模块,用于利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对所述待训练图像的识别结果;
调整模块,用于按照所述识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,以及按照所述损失值调整所述学生模型的模型参数。
本申请还提供了另一种终端设备,所述终端设备包括存储器和处理器,其中,所述存储器与所述处理器耦接;
其中,所述存储器用于存储程序数据,所述处理器用于执行所述程序数据以实现上述的模型训练方法。
本申请还提供了一种计算机存储介质,所述计算机存储介质用于存储程序数据,所述程序数据在被处理器执行时,用以实现上述的模型训练方法。
本申请的有益效果是:终端设备获取待训练图像;采用待训练图像训练多个阶段的老师模型;利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对待训练图像的识别结果;按照识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用学生模型的预测输出值和老师模型的预测输出值计算学生模型的损失值,以及按照损失值调整所述学生模型的模型参数。通过上述方式,本申请的模型训练方法采用多个阶段的老师模型分别对不同阶段的学生模型进行训练,从而使得学生模型在假扮老师模型的过程中不断接近甚至超过老师模型的性能。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。其中:
图1是本申请提供的模型训练方法一实施例的流程示意图;
图2是本申请提供的多阶段多生成对抗网络的在线知识蒸馏系统的框架示意图;
图3是本申请提供的模型训练方法另一实施例的流程示意图;
图4是本申请提供的模型训练方法又一实施例的流程示意图;
图5是本申请提供的生成-对抗网络的框架示意图;
图6是本申请提供的终端设备一实施例的结构示意图;
图7是本申请提供的终端设备另一实施例的结构示意图;
图8是本申请提供的计算机存储介质一实施例的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
深度神经网络在工业界和学术界都取得了巨大的成功,特别是在计算机视觉任务方面。深度学习的巨大成功主要是因为它可以对大规模数据进行编码,并且可以操纵数十亿个模型参数。然而,在资源有限的设备(如移动电话和嵌入式设备)上部署这些笨重的模型是一个挑战,不仅因为计算复杂度高,而且还因为存储需求大。为此,开发了各种模型压缩和加速技术。知识蒸馏作为模型压缩技术之一,其关键挑战是如何从老师模型中提取丰富而通用的知识,从而缩小学生模型和老师模型之间的性能差距。
对此,本申请设计了一种基于多阶段多生成对抗网络的在线知识蒸馏系统,实现了学生模型和老师模型在多阶段的协同学习,解决了强收敛的老师模型对学生模型在蒸馏过程中导致精度损失的问题,从而使得学生模型能够更好地紧随甚至超越老师的性能。
具体请参阅图1和图2,图1是本申请提供的模型训练方法一实施例的流程示意图,图2是本申请提供的多阶段多生成对抗网络的在线知识蒸馏系统的框架示意图。
其中,本申请的模型训练方法应用于一种终端设备,其中,本申请的终端设备可以为服务器,也可以为由服务器和电子设备相互配合的系统。相应地,终端设备包括的各个部分,例如各个单元、子单元、模块、子模块可以全部设置于服务器中,也可以分别设置于服务器和电子设备中。
进一步地,上述服务器可以是硬件,也可以是软件。当服务器为硬件时,可以实现成多个服务器组成的分布式服务器集群,也可以实现成单个服务器。当服务器为软件时,可以实现成多个软件或软件模块,例如用来提供分布式服务器的软件或软件模块,也可以实现成单个软件或软件模块,在此不做具体限定。在一些可能的实现方式中,本申请实施例的模型训练方法可以通过处理器调用存储器中存储的计算机可读指令的方式来实现。
具体而言,如图1所示,本申请实施例的模型训练方法具体包括以下步骤:
步骤S11:获取待训练图像。
在本申请实施例中,终端设备获取包括若干待训练图像组成的训练集,图像类型以及图像内容并不做限制。
进一步地,为了达到更好的模型训练效果,终端设备还可以将训练集中的若干待训练图像进行一系列的预处理,包括但不限于:对待训练图像的像素大小进行扩充、随机裁剪,和正则化等相关操作。终端设备采用以上处理操作对训练集进行批处理,划分为多个组,用于后续的模型训练。具体地,本申请实施例实施预处理的目的是为了降低待训练图像的干扰和噪声,以及适应模型处理图像的能力。
另外,终端设备还可以对本申请实施例中的老师模型和/或学生模型也进行模型预处理。具体地,不同的模型对于所处理的待训练图像有不同的需求。例如,对于ResNet模型(残差网络模型)而言,终端设备可以删除ResNet模型中的第一个最大池化层,并对相关的卷积层参数进行调整,使得ResNet模型对输入图像有着更好的泛化能力,从而减少不同训练集对ResNet模型性能的影响。对于不同类型的网络模型,可以采用不同的模型预处理方式,在此不再一一列举。
步骤S12:采用待训练图像训练多个阶段的老师模型。
在本申请实施例中,终端设备采用学生-老师学习机制实现对老师模型和学生模型的训练。下面先介绍学生-老师学习机制的内容:
为了更好地说明,假设WT表示老师模型的模型参数,WS表示学生模型的模型参数。PT=Softmax(ZT)表示老师模型的预测输出值,PS=Softmax(ZS)表示学生模型的预测输出值,ZT表示老师模型的logits值(预测概率值),ZS表示学生模型的logits值。
KD(Knowledge Distillation,知识蒸馏)的思想是将老师模型的输出作为软目标来指导学生模型的训练。同时,用标签y计算学生模型的输出,具体地,计算KD的损失值如下:
Figure BDA0003291218330000081
其中,LCE表示交叉熵损失值,y表示标签的一个热向量,t表示温度超参数,T表示蒸馏温度值,δ表示一个权衡超参数。上述式子(1)中,第一项是使用标签y定义的交叉熵损失值,第二项是鼓励学生模型模仿老师模型的类别软化分数。
通常来说,老师模型调用一个大而复杂的模型来达到较低的局部最优值,但这种蒸馏方法存在一些问题。首先,由于老师模型和学生模型之间存在着巨大的性能差距,一个小而浅的学生模型很难模仿甚至超越老师模型。其次,模型训练通常采用随机梯度下降法促使损失函数最小化。由于损失函数的高非凸性,在训练过程中会出现许多局部最优值。当网络收敛到某个局部最小值时,无论初始化方式如何,其训练损失都会收敛到某个值,或类似的值。
因此,本申请实施例采用多阶段的方式实现学生-老师学习机制,即应该对学生模型分阶段进行训练。在对学生模型分阶段训练之前,终端设备需要将老师模型的训练分为几个阶段,得到每个阶段的训练完成的老师模型来指导学生模型。
具体地,终端设备在一个阶段的训练中,初始化老师模型的模型参数,将待训练图像输入老师模型;按照第一预设迭代次数对老师模型进行训练后,冻结老师模型在这一阶段的模型参数;在下一个阶段的训练中,重新初始化老师模型的模型参数,将待训练图像输入老师模型;按照第二预设迭代次数对老师模型进行训练后,冻结老师模型在这一阶段的模型参数。其中,第一预设迭代次数和第二预设迭代次数可以一致,也可以不同,由阶段性的训练计划决定。
例如,待训练图像作为老师模型的输入,从老师模型中得到输出结果,以标签y计算交叉熵损失,最后对老师模型进行反向传播。假设老师模型的训练分为N个阶段,第i个阶段记为Stagei。终端设备对老师模型的模型参数进行随机初始化,并在每个阶段降低学习率,以更好地达到较低的局部最优值。因此,在阶段Stagei中老师模型的训练损失值可以表示为:
Figure BDA0003291218330000091
经过阶段Stagei的训练后,被训练的老师模型Ti的模型参数被冻结,并参与到后续相同阶段的学生模型的训练过程中。通过这种分阶段的训练方法,鼓励学生模型更容易、更快地模仿甚至超越老师模型。
步骤S13:利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对待训练图像的识别结果。
在本申请实施例中,终端设备利用步骤S12训练得到的各个阶段的老师模型分别对相同阶段的学生模型进行训练,如图2所示,分别得到对于待训练图像的识别结果。
步骤S14:按照识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用学生模型的预测输出值和老师模型的预测输出值计算学生模型的损失值,以及按照损失值调整学生模型的模型参数。
在本申请实施例中,终端设备按照识别结果获取老师模型的预测输出值和学生模型的预测输出值,并如图2所示,利用学生模型的预测输出值和老师模型的预测输出值计算学生模型的均方误差损失值(MSE,Mean Square Error),以及按照均方误差损失值调整学生模型的模型参数。通过不停迭代,使得迭代过程中均方误差损失值不断变小,直至小于预设损失值。
进一步地,如图2所示,终端设备还可以利用学生模型的预测输出值(output)和预设的目标输出值(Hard target),计算学生模型的交叉熵损失值(CE,Cross EntropyLoss),该交叉熵损失值表征了学生模型的预测输出值与目标输出值之间的差距。
终端设备将上述均方误差损失值和交叉熵损失值组成得到学生模型的损失值,通过不停迭代,使得迭代过程中学生模型的损失值不断变小,直至小于预设损失值。训练完成后的学生模型一方面与老师模型的差距变小,另一方面学生模型的识别准确率有一定的提高。
在本申请实施例中,终端设备获取待训练图像;采用待训练图像训练多个阶段的老师模型;利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对待训练图像的识别结果;按照识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用学生模型的预测输出值和老师模型的预测输出值计算学生模型的损失值,以及按照损失值调整所述学生模型的模型参数。通过上述方式,本申请的模型训练方法采用多个阶段的老师模型分别对不同阶段的学生模型进行训练,从而使得学生模型在假扮老师模型的过程中不断接近甚至超过老师模型的性能。
请继续参阅图3,图3是本申请提供的模型训练方法另一实施例的流程示意图。
具体而言,如图3所示,本申请实施例的模型训练方法具体包括以下步骤:
步骤S21:将同一阶段的老师模型和学生模型按照相同划分方式将网络模型划分为相同数量的卷积块。
传统的知识蒸馏方法直接将老师模型的输出作为软目标,以最大限度地减少学生模型输出的损失。然而,老师模型所传递的知识不一定对学生模型有帮助。最好的情况是,学生模型能够正确地学习最重要的细节,而省略不必要的细节,这不会影响他们在特定任务中的表现。
因此,在本申请实施例中,终端设备采用逐层贪婪训练的方式在学生模型训练中,将学生模型的训练过程以卷积块的形式划分为M个卷积块,例如,第i个老师模型对学生模型中第j个卷积块的训练称为Stagei,j
步骤S22:利用老师模型相应卷积块的预测输出值对学生模型的每一个卷积块进行训练,获取学生模型中目标卷积块的预测输出值,以及老师模型中相应卷积块的预测输出值。
在本申请实施例中,与直接的端到端的训练模式不同,终端设备对学生模型的训练也要分阶段进行,即一次一个卷积块训练。
具体地,待训练图像分别作为老师模型和学生模型的输入,第一个卷积块,即老师模型的参考卷积块和学生模型的目标卷积块的输出从这两个模型中取出,即获取学生模型中目标卷积块的预测输出值,以及老师模型中相应卷积块的预测输出值。
步骤S23:利用目标卷积块的预测输出值和相应卷积块的预测输出值计算目标卷积块的损失值,以及按照损失值调整目标卷积块的模型参数。
在本申请实施例中,终端设备利用目标卷积块的预测输出值和参考卷积块的预测输出值计算目标卷积块的均方误差损失值,即图2中的两个卷积块之间的MSE,然后对学生模型进行反向传播。
当前阶段的模型训练结束后,在下一个阶段的模型训练中,再次将待训练图像分别输入老师模型和学生模型,但输出从下一个目标卷积块和下一个参考卷积块取出。终端设备遵循以上步骤,对老师模型和学生模型的所有卷积块都重复训练。其中,对于老师模型Ti指导学生模型Si中间层,即目标卷积块的训练损失值如下:
Figure BDA0003291218330000111
其中,
Figure BDA0003291218330000112
表示老师模型Ti第j个卷积块的预测输出值,
Figure BDA0003291218330000113
表示学生模型Si第j个卷积块的预测输出值。
通过以上逐层贪婪训练,学生模型可以逐渐学习到老师模型的一些重要细节,使得学生模型能够更好地模仿老师模型。
但是,由于不同的模型之间存在着一定的差异,模式的训练方式也不尽相同。逐层贪婪策略直接计算中间层结果的损失,忽略了不同模型之间实例转换方式的差异以及标签对模型训练的重要性。为此,本申请引入了多生成对抗网络来解决这一问题。
请继续参阅图4和图5,图4是本申请提供的模型训练方法又一实施例的流程示意图,图5是本申请提供的生成-对抗网络的框架示意图。
具体而言,如图4所示,本申请实施例的模型训练方法具体包括以下步骤:
步骤S31:将学生模型中的目标卷积块作为生成器。
步骤S32:将老师模型中的参考卷积块后面的卷积块作为鉴别器。
步骤S33:利用生成器和鉴别器组成生成对抗网络。
在本申请实施例中,考虑到GAN(生成对抗网络)的收敛困难和梯度消失的问题,本申请没有引入额外的生成器和鉴别器,而是在模型内部使用一定数量的卷积块来替代它们。其中,GAN中G是generator,生成器:负责凭空捏造数据出来,D是discriminator,判别器:负责判断数据是不是真数据。
假设学生模型的前b个卷积块(1,2,...,b)作为生成器(b≥0),对于老师模型后面的剩下卷积块(b+1,...,M)作为鉴别器。本申请实施例通过在多阶段训练模式中引入多生成网络和对抗性网络,提高了训练学生模型的训练效率。
步骤S34:将待训练图像输入生成对抗网络,获取生成对抗网络的损失值。
步骤S35:利用生成对抗网络的损失值,调整目标卷积块的模型参数。
在本申请实施例中,首先从任意香草GAN开始,GAN的经典公式原理是在生成器Gx和判别器Dx之间的极大极小博弈,目标函数可以表示为:
Figure BDA0003291218330000121
在本申请中,由于真实标签的输入,传统的GAN训练是不可行的。因此,本申请实施例提出了一种生成-对抗网络的整体培训框架。
对于多生成网络,首先对于生成器的定义,本申请设计的生成器不同于其他直接引入生成器模型的方法,本申请利用了多阶段训练的优势,将学生模型的前b个卷积块定义为特征生成器。假设M为生成器的总组数,老师模型和学生模型的卷积块组数相同。
因此,生成器可以用{G1,G2,G3,...,GM}表示,则输入待训练图像I和第j组生成器的输出用
Figure BDA0003291218330000131
表示:
Figure BDA0003291218330000132
Figure BDA0003291218330000133
当j=B时,即学生模型的最后一个卷积块的输出,也就等价于
Figure BDA0003291218330000134
对于多对抗网络,由于生成器由学生模型的多个卷积块组成,因此,相应地构造了由老师模型的多个卷积块组成的鉴别器。假设M也是鉴别器的总组数,鉴别器可以用{D1,D2,D3,...,DM}表示,得到满足的Gj*
Figure BDA0003291218330000135
其中,Dj*表示第j个卷积块的最优鉴别器。用传统方法训练鉴别器也是不可能的,因此,本申请实施例将这个最大和最小博弈转化为生成样本和真实样本之间差异的最小化。这样,在训练第.j组时,只优化了Gj,而{Dj+1,Dj+2,...,DM}是固定的,Dj的输出输入到{Dj +1,Dj+2,...,DM},得到它的“假”输出用于分类标签。
需要说明的是,对于第j个卷积块的训练,生成器指的是学生模型{G1,G2,...,Gj}定义为G1,j,而鉴别器对应老师模型{Dj+1,...,DM}定义为Dj+1,M
因此,对于生成器G1,j,得到改进的多生成对抗网络的训练损失值为:
Figure BDA0003291218330000136
计算通过
Figure BDA0003291218330000137
得到的“假”结果和D(I)得到的“真”结果进行损失最小化,鼓励学生模型模仿老师模型的实例转换过程。另一方面,最小化
Figure BDA0003291218330000141
和标签y的损失,以摆脱老师模型的限制,从而可能超越老师模型的表现。
进一步地,通过LGAN完成对学生模型的多个生成器的训练后,得到由多个生成器组成的目标网络。在训练学生模型的最后一个块DM时,我们只计算了学生-老师中间层损失值和真实标签损失值。最终的损失函数是通过对所有损失进行整合得到的:
Loss=LMid+αLgan+βLGT+γLIRG-t (9)
其中,α,β,γ分别表示权衡超参数,LMid表示{Dj,Gj}组输出的均方误差损失值,Lgan表示计算多生成对抗网络的最小损失值,LGT表示学生模型与标签数据之间的计算损失值。
需要说明的是,本申请只需要使用LMid和LGT来计算学生模型的最后一个块DM的损失值,另外,考虑到实例空间的信息,引入了LMTK,但由于计算代价高,可以考虑选择LIRG-t替代进行最小损失计算。
本申请实施例结合在线蒸馏和逐层贪婪训练方法,提出了一种多阶段学习策略,解决了强收敛性的老师模型导致的蒸馏精度损失增大的问题;针对传统生成式对抗网络模型难以训练和梯度消失的问题,采用生成式对抗策略与逐层贪婪训练训练相结合的方法,实现了一种多阶段多生成式对抗网络模型;引入中间层损失、空间转换损失、软目标损失和硬目标损失,实现了模型更好的在线蒸馏过程。在CIFAR10/100、ImageNette/ImageWoof数据集上测试的性能优于其他先进的知识蒸馏方法。
本领域技术人员可以理解,在具体实施方式的上述方法中,各步骤的撰写顺序并不意味着严格的执行顺序而对实施过程构成任何限定,各步骤的具体执行顺序应当以其功能和可能的内在逻辑确定。
为实现上述实施例的模型训练方法,本申请还提出了一种终端设备,具体请参阅图6,图6是本申请提供的终端设备一实施例的结构示意图。
如图6所示,本申请提供的终端设备400包括获取模块41、训练模块42以及调整模块43。
其中,获取模块41,用于获取待训练图像。
训练模块42,用于采用所述待训练图像训练多个阶段的老师模型。
训练模块42,用于利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对所述待训练图像的识别结果。
调整模块43,用于按照所述识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,以及按照所述损失值调整所述学生模型的模型参数。
为实现上述实施例的模型训练方法,本申请还提出了另一种终端设备,具体请参阅图7,图7是本申请提供的终端设备另一实施例的结构示意图。
本申请实施例的终端设备500包括存储器51和处理器52,其中,存储器51和处理器52耦接。
存储器51用于存储程序数据,处理器52用于执行程序数据以实现上述实施例所述的模型训练方法。
在本实施例中,处理器52还可以称为CPU(Central Processing Unit,中央处理单元)。处理器52可能是一种集成电路芯片,具有信号的处理能力。处理器52还可以是通用处理器、数字信号处理器(DSP,Digital Signal Process)、专用集成电路(ASIC,ApplicationSpecific Integrated Circuit)、现场可编程门阵列(FPGA,Field Programmable GateArray)或者其它可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。通用处理器可以是微处理器或者该处理器52也可以是任何常规的处理器等。
本申请还提供一种计算机存储介质,如图8所示,计算机存储介质800用于存储程序数据61,程序数据61在被处理器执行时,用以实现如上述实施例所述的模型训练方法。
本申请还提供一种计算机程序产品,其中,上述计算机程序产品包括计算机程序,上述计算机程序可操作来使计算机执行如本申请实施例所述的模型训练方法。该计算机程序产品可以为一个软件安装包。
本申请上述实施例所述的模型训练方法,在实现时以软件功能单元的形式存在并作为独立的产品销售或使用时,可以存储在装置中,例如一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本发明各个实施方式所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅为本申请的实施方式,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。

Claims (10)

1.一种模型训练方法,其特征在于,所述模型训练方法包括:
获取待训练图像;
采用所述待训练图像训练多个阶段的老师模型;
利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对所述待训练图像的识别结果;
按照所述识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,以及按照所述损失值调整所述学生模型的模型参数。
2.根据权利要求1所述的模型训练方法,其特征在于,
所述采用所述待训练图像训练多个阶段的老师模型,包括:
在一个阶段的训练中,初始化所述老师模型的模型参数,将所述待训练图像输入所述老师模型;
按照第一预设迭代次数对所述老师模型进行训练后,冻结所述老师模型在这一阶段的模型参数;
在下一个阶段的训练中,重新初始化所述老师模型的模型参数,将所述待训练图像输入所述老师模型;
按照第二预设迭代次数对所述老师模型进行训练后,冻结所述老师模型在这一阶段的模型参数。
3.根据权利要求1所述的模型训练方法,其特征在于,
所述利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,包括:
利用所述学生模型的预测输出值和所述老师模型的预测输出值,计算所述学生模型的均方误差损失值;
利用所述学生模型的预测输出值和预设的目标输出值,计算所述学生模型的交叉熵损失值;
将所述均方误差损失值和所述交叉熵损失值组成得到所述学生模型的损失值。
4.根据权利要求1所述的模型训练方法,其特征在于,
所述模型训练方法,还包括:
将同一阶段的老师模型和学生模型按照相同划分方式将网络模型划分为相同数量的卷积块;
利用所述老师模型相应卷积块的预测输出值对所述学生模型的每一个卷积块进行训练,获取所述学生模型中目标卷积块的预测输出值,以及所述老师模型中相应卷积块的预测输出值;
利用所述目标卷积块的预测输出值和所述相应卷积块的预测输出值计算所述目标卷积块的损失值,以及按照所述损失值调整所述目标卷积块的模型参数。
5.根据权利要求4所述的模型训练方法,其特征在于,
所述模型训练方法,还包括:
将所述待训练图像分别输入所述老师模型和所述学生模型,获取所述老师模型中参考卷积块的预测输出值和所述学生模型中目标卷积块的预测输出值,其中,所述参考卷积块和所述目标卷积块在网络模型中的位置相同;
利用所述目标卷积块的预测输出值和所述参考卷积块的预测输出值计算所述目标卷积块的均方误差损失值,以及按照所述均方误差损失值调整所述目标卷积块的模型参数。
6.根据权利要求4所述的模型训练方法,其特征在于,
所述模型训练方法,还包括:
将所述学生模型中的目标卷积块作为生成器;
将所述老师模型中的参考卷积块后面的卷积块作为鉴别器;
利用所述生成器和所述鉴别器组成生成对抗网络;
将所述待训练图像输入所述生成对抗网络,获取所述生成对抗网络的损失值;
利用所述生成对抗网络的损失值,调整所述目标卷积块的模型参数。
7.根据权利要求6所述的模型训练方法,其特征在于,
所述将所述待训练图像输入所述生成对抗网络,获取所述生成对抗网络的损失值,包括:
将所述待训练图像分别输入所述生成器和所述鉴别器;
将所述生成器基于所述待训练图像生成的鉴别图像,输入所述鉴别器;
获取所述鉴别器训练所述待训练图像得到的第一损失值,和训练所述鉴别图像得到的第二损失值,并组成所述生成对抗网络的损失值。
8.一种终端设备,其特征在于,所述终端设备包括:
获取模块,用于获取待训练图像;
训练模块,用于采用所述待训练图像训练多个阶段的老师模型;
所述训练模块,用于利用不同阶段的老师模型分别对学生模型进行训练,获取每个阶段的老师模型和学生模型对所述待训练图像的识别结果;
调整模块,用于按照所述识别结果获取老师模型的预测输出值和学生模型的预测输出值,并利用所述学生模型的预测输出值和所述老师模型的预测输出值计算所述学生模型的损失值,以及按照所述损失值调整所述学生模型的模型参数。
9.一种终端设备,其特征在于,所述终端设备包括存储器和处理器,其中,所述存储器与所述处理器耦接;
其中,所述存储器用于存储程序数据,所述处理器用于执行所述程序数据以实现权利要求1-7任一项所述的模型训练方法。
10.一种计算机存储介质,其特征在于,所述计算机存储介质用于存储程序数据,所述程序数据在被处理器执行时,用以实现权利要求1-7任一项所述的模型训练方法。
CN202111164452.XA 2021-09-30 一种模型训练方法、终端设备及计算机存储介质 Active CN114065834B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111164452.XA CN114065834B (zh) 2021-09-30 一种模型训练方法、终端设备及计算机存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111164452.XA CN114065834B (zh) 2021-09-30 一种模型训练方法、终端设备及计算机存储介质

Publications (2)

Publication Number Publication Date
CN114065834A true CN114065834A (zh) 2022-02-18
CN114065834B CN114065834B (zh) 2024-07-02

Family

ID=

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116681790A (zh) * 2023-07-18 2023-09-01 脉得智能科技(无锡)有限公司 一种超声造影图像生成模型的训练方法及图像的生成方法

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110147456A (zh) * 2019-04-12 2019-08-20 中国科学院深圳先进技术研究院 一种图像分类方法、装置、可读存储介质及终端设备
CN112465138A (zh) * 2020-11-20 2021-03-09 平安科技(深圳)有限公司 模型蒸馏方法、装置、存储介质及设备
CN112527127A (zh) * 2020-12-23 2021-03-19 北京百度网讯科技有限公司 输入法长句预测模型的训练方法、装置、电子设备及介质
CN113222123A (zh) * 2021-06-15 2021-08-06 深圳市商汤科技有限公司 模型训练方法、装置、设备及计算机存储介质

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110147456A (zh) * 2019-04-12 2019-08-20 中国科学院深圳先进技术研究院 一种图像分类方法、装置、可读存储介质及终端设备
CN112465138A (zh) * 2020-11-20 2021-03-09 平安科技(深圳)有限公司 模型蒸馏方法、装置、存储介质及设备
CN112527127A (zh) * 2020-12-23 2021-03-19 北京百度网讯科技有限公司 输入法长句预测模型的训练方法、装置、电子设备及介质
CN113222123A (zh) * 2021-06-15 2021-08-06 深圳市商汤科技有限公司 模型训练方法、装置、设备及计算机存储介质

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
刘尚争;刘斌;: "生成对抗网络图像类别标签跨模态识别系统设计", 现代电子技术, no. 08, 15 April 2020 (2020-04-15) *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116681790A (zh) * 2023-07-18 2023-09-01 脉得智能科技(无锡)有限公司 一种超声造影图像生成模型的训练方法及图像的生成方法
CN116681790B (zh) * 2023-07-18 2024-03-22 脉得智能科技(无锡)有限公司 一种超声造影图像生成模型的训练方法及图像的生成方法

Similar Documents

Publication Publication Date Title
CN110263912B (zh) 一种基于多目标关联深度推理的图像问答方法
CN108875807B (zh) 一种基于多注意力多尺度的图像描述方法
CN109891897B (zh) 用于分析媒体内容的方法
US20180018555A1 (en) System and method for building artificial neural network architectures
CN109558576B (zh) 一种基于自注意力机制的标点符号预测方法
CN110347873A (zh) 视频分类方法、装置、电子设备及存储介质
CN103049792A (zh) 深层神经网络的辨别预训练
CN112580694B (zh) 基于联合注意力机制的小样本图像目标识别方法及系统
CN113516133B (zh) 一种多模态图像分类方法及系统
CN112417752B (zh) 基于卷积lstm神经网络的云层轨迹预测方法及系统
Dai et al. Hybrid deep model for human behavior understanding on industrial internet of video things
CN110930996A (zh) 模型训练方法、语音识别方法、装置、存储介质及设备
Du et al. Efficient network construction through structural plasticity
CN115908641A (zh) 一种基于特征的文本到图像生成方法、装置及介质
CN111626404A (zh) 基于生成对抗神经网络的深度网络模型压缩训练方法
Milutinovic et al. End-to-end training of differentiable pipelines across machine learning frameworks
CN113554040B (zh) 一种基于条件生成对抗网络的图像描述方法、装置设备
Qi et al. Learning low resource consumption cnn through pruning and quantization
CN117634459A (zh) 目标内容生成及模型训练方法、装置、系统、设备及介质
CN111783688B (zh) 一种基于卷积神经网络的遥感图像场景分类方法
CN114065834A (zh) 一种模型训练方法、终端设备及计算机存储介质
CN114065834B (zh) 一种模型训练方法、终端设备及计算机存储介质
CN116229323A (zh) 一种基于改进的深度残差网络的人体行为识别方法
CN115063374A (zh) 模型训练、人脸图像质量评分方法、电子设备及存储介质
CN114861917A (zh) 贝叶斯小样本学习的知识图谱推理模型、系统及推理方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant