CN114519416A - 模型蒸馏方法、装置及电子设备 - Google Patents
模型蒸馏方法、装置及电子设备 Download PDFInfo
- Publication number
- CN114519416A CN114519416A CN202111680113.7A CN202111680113A CN114519416A CN 114519416 A CN114519416 A CN 114519416A CN 202111680113 A CN202111680113 A CN 202111680113A CN 114519416 A CN114519416 A CN 114519416A
- Authority
- CN
- China
- Prior art keywords
- label
- ith
- data
- recognition result
- probability
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明实施例涉及一种模型蒸馏方法、装置及电子设备,该方法包括:对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;根据一组第一识别结果,获取第二识别结果;当至少一个备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将第一备选标签设定为未标记数据的伪标签;对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束。通过该方式,大大降低人力标记的成本。
Description
技术领域
本发明实施例涉及计算机技术领域,尤其涉及一种模型蒸馏方法、装置及电子设备。
背景技术
当深度神经网络通过监督学习实现其强大性能的时候,常常需要一个标记的数据集(也即是已知类别的样本数据集)。因为标记数据通常需要人力完成,所以数据集的标记则需要消耗一定的人力成本。数据集越大,其所带来的性能优势更强,随之而来的就是巨大的数据集可能会带来巨大的人力成本,而当标记数据的工作必须由专家完成时,这种成本可能会更高。
模型蒸馏是一种常用的模型压缩方法,首先训练一个大的教师模型,然后使用教师模型输出的预测值训练小的学生模型。学生模型学习教师模型的预测结果(概率值)从而学习到教师模型的泛化性能。
目前的蒸馏方法也是在监督数据中进行的,自然也需要大量已被标记的数据集,因而导致模型蒸馏方法的人力成本随之增加。
发明内容
本申请提供了一种模型蒸馏方法、装置及电子设备,以解决现有技术中模型蒸馏方法使用大量被标记数据集,导致模型蒸馏方法中的人力成本过高的技术问题。
第一方面,本申请提供了一种模型蒸馏方法,该方法包括:
对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;
根据一组第一识别结果,获取第二识别结果,其中第二识别结果包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率;
当至少一个备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将第一备选标签设定为未标记数据的伪标签,其中,第一备选标签为至少一个备选标签中概率最大的标签;
对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;
根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束,其中,i为正整数。
在一种可选的实施方式中,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果之前,方法还包括:
利用已标记数据分别对至少两个教师网络模型中的每一个教师网络模型进行训练。
在一种可选的实施方式中,确定学生网络模型是否符合预设标准,具体包括:
根据第一损失函数和第二损失函数,确定目标损失;
当目标损失达到预设损失阈值时,确定学生网络模型符合预设标准。
第二方面,本申请提供了一种多模型蒸馏装置,该装置包括:
处理模块,用于对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;根据一组第一识别结果,获取第二识别结果,其中第二识别结果包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率;
设定模块,用于当至少一个备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将第一备选标签设定为未标记数据的伪标签,其中,第一备选标签为至少一个备选标签中概率最大的标签;
处理模块,还用于对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;
优化模块,根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束,其中,i为正整数。
在一种可选的实施方式中,处理模块,还用于利用已标记数据分别对至少两个教师网络模型中的每一个教师网络模型进行训练。
在一种可选的实施方式中,优化模块,具体用于根据第一损失函数和第二损失函数,确定目标损失;
当目标损失达到预设损失阈值时,确定学生网络模型符合预设标准。
第三方面,提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现第一方面任一项实施例的模型蒸馏方法的步骤。
第四方面,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现如第一方面任一项实施例的模型蒸馏方法的步骤。
本申请实施例提供的上述技术方案与现有技术相比具有如下优点:
本申请实施例提供的该方法,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,其中教师网络模型为已经训练好的教师网络模型。得到一组第一识别结果,然后根据一组第一识别结果,获取第二识别结果。在第二识别结果中,包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率。当至少一个备选标签中的第一备选标签的预测概率大于或者等于预设概率阈值时,则可以认定第一备选标签为与第i个未标记数据对应的伪标签。然后,对第i个未标记数据进行第二泛化处理后,输入到学生网络模型,获取第三识别结果,将第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束。该过程中,考虑到同一个未标记数据对应的标签具有唯一性,且因为教师网络模型为训练好的网络模型,其最终得出的训练结果具有一定的有效性。因此,可以利用第二识别结果和伪标签,以及第三识别结果共同作为参考依据,对学生网络模型进行修正,重复执行上述过程,直至学生网络模型达到预设标准结束。因为对学生网络模型的训练是采用未标记数据,不涉及已标记数据,因此大大降低了人力标记的成本,而且同样可以达到与利用已标记数据训练学生网络模型相同或类似的效果。
附图说明
图1为本发明提供的一种训练教师模型的方法流程示意图;
图2为本发明实施例提供的一种模型蒸馏方法流程示意图;
图3为本发明实施例提供的另一种模型蒸馏方法流程示意图;
图4为本发明实施例提供的另一种模型蒸馏方法流程示意图;
图5为本发明实施例提供的另一种模型蒸馏方法流程示意图;
图6为本发明提供的一种模型蒸馏方法的一个具体应用实例的方法流程示意图;
图7为本发明实施例提供的一种多模型蒸馏装置结构示意图;
图8为本发明实施例提供一种电子设备结构示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
为便于对本发明实施例的理解,下面将结合附图以具体实施例做进一步的解释说明,实施例并不构成对本发明实施例的限定。
针对背景技术中所提及的技术问题,本申请实施例提供了一种模型蒸馏方法。
具体的,在多模型蒸馏方法中,首先需要训练一个大的教师模型,然后使用教师模型输出的预测值训练小的学生模型。因此,在执行本方法步骤之前,首先需要训练教师模型,然后才会使用已经训练好的教师模型,训练学生模型。也即是在介绍本发明实施例的方法步骤之前,首先说明执行本发明实施例方法步骤之前的一些准备工作。
具体方法步骤参见图1所示,该方法包括:
步骤110,收集样本数据。
具体的,以分类任务为例,根据任务需求,确定分类类别,收集不同场景的数据,对数据进行预处理,并将预处理后的数据分配到不同的类别。然后,按照一定的比例,将所有收集到的数据划分为训练集、验证集和测试集,用以方便后续对教师网络模型进行训练、验证和测试等操作。
步骤120,利用已标记数据分别对至少两个教师网络模型中的每一个教师网络模型进行训练。
具体的,对于教师网络模型的训练,可以采用多种方式,以增加最终得出的教师网络训练模型多样化,不同教师网络模型得到的结果可以互补。
在一个具体的例子中,例如按照数据不同,将训练数据分成多个彼此不同的训练集(M1,M2,……,Mn),分别在每个训练集上训练各自独立的教师网络,最终就可以得到多个教师网络模型。
在另一个具体的例子中,例如按照网络结构不同,训练得到多个不同的教师网络模型,网络结构例如可以包括但不限于如下中的一种或多种: resnet,googlenet,mobilenet等等。通过网络结构的不同,训练多个教师网络模型。
而最终在本方法实施例中应用的教师网络模型,则可以选择上述所介绍的任一种或多种训练方式所得到的多个教师网络模型。
优选的,考虑到本实施例中期望学生网络模型的预测结果能够更接近实际,所以综合选择上述两种方式中的多个教师网络模型。以使得多个教师网络模型可以从多个维度来对未标记数据进行预测,效果相较于使用单个教师网络模型,或者是单纯某种方式获取的教师网络模型而言,更加的全面,不同教师网络模型所处结果可以互补。
在获取到经过训练好的教师网络模型后,执行如下方法步骤,参见图 2所示,图2为本发明实施例提供的一种模型蒸馏方法流程示意图。该方法步骤包括:
步骤210,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果。
具体的,对第i个未标记数据进行泛化处理,目的是为了增加未标记数据的泛化性能。其主要考虑到因为数据为未标记数据,为了增加模型的识别能力,所以需要执行泛化处理。需要说明的是,在进入到不同的教师网络模型之前,对第i个未标记的数据执行的泛化处理可以不同,因此在经过教师网络模型预测后,所得到的结果在数值上也可能稍有差异。一个教师网络模型输出一个第一识别结果,所以最终获取的是一组第一识别结果。而,一组第一识别结果中每一个识别结果所包括的内容类型相同或相应(通常包括识别标签以及每个标签对应的概率),但是有可能(对应标签的概率)数值上可能稍有不同。
步骤220,根据一组第一识别结果,获取第二识别结果。
具体的,在获取到不同教师网络模型的不同第一识别结果后,可以对不同第一识别结果进行集成,而得到第二识别结果。对于将一组第一识别结果进行集成,其目的也是如上文所介绍的,希望能够综合多个不同教师网络模型的预测结果,使其最终的预测结果能够更贴近实际情况。
在一个具体的例子中,第二识别结果包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率。
步骤230,当至少一个备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将第一备选标签设定为未标记数据的伪标签。
具体的,如上所介绍的,将多个第一识别结果进行集成,就是为了使得最终所获取的结果能够更加贴近实际。而第二识别结果中,实际上是包括至少一个备选标签,以及每一个备选标签所对应的预测概率的。
根据每一个备选标签对应的预测概率,从中选取预测概率最大的备选标签,作为第一备选标签。并确定第一备选标签的预测概率是否大于或者等于预设概率阈值。
如果第一备选标签的预测概率大于或者等于预设概率阈值时,则可以设定第一备选标签为未标记数据的伪标签。
在一个可选的例子中,如果第一备选标签的预测概率小于预设概率阈值,则可以舍弃掉该未标记数据。i自增1,重新获取下一个未标记数据,然后重新执行上述所有操作,直至确定某一个未标记的数据的第一备选标签的预测概率大于或者等于预设概率阈值。
之所以执行此操作,是考虑到第一备选标签的预测概率既然都达不到预设概率阈值,(也即是对该数据的预测标签可能与该数据的实际标签之间的差距太大)那么该标签对于学生网络模型的优化也不会带来太大意义,所以直接舍弃掉。
然后再执行下文所介绍的操作流程。
步骤240,对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果。
具体的,同上文所介绍的,对于未标记数据进行第二泛化处理,同样是为了增加未标记数据的泛化性能。然后将经过第二泛化处理后的第i个未标记数据输入到学生网络模型中,得到第三识别结果。
步骤250,根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束。
具体的,第二识别结果是集成多个教师网络模型的输出结果的成果,第三识别结果则是学生网络模型的输出结果,根据一致性原则,不同网络输出模型针对同一个未标记数据的预测,目标结果应该一致。但是考虑到学生网络是未经过已标记数据训练过的网络模型,所以可以借用通过已标记数据的教师网络预测结果对学生网路进行修正。而之所以还借用伪标签,同样是考虑到伪标签是多个教师网络模型输出结果的集成后的“成果”,学生网络输出的第三识别结果中的目标标签和伪标签之间的差距,同样可以作为一种“负反馈”对学生网络模型进行优化,使得学生网络模型可以达到更优。
本发明实施例提供的模型蒸馏方法,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,其中教师网络模型为已经训练好的教师网络模型。得到一组第一识别结果,然后根据一组第一识别结果,获取第二识别结果。在第二识别结果中,包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率。当至少一个备选标签中的第一备选标签的预测概率大于或者等于预设概率阈值时,则可以认定第一备选标签为与第i个未标记数据对应的伪标签。然后,对第i 个未标记数据进行第二泛化处理后,输入到学生网络模型,获取第三识别结果,将第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束。该过程中,考虑到同一个未标记数据对应的标签具有唯一性,且因为教师网络模型为训练好的网络模型,其最终得出的训练结果具有一定的有效性。因此,可以利用第二识别结果和伪标签,以及第三识别结果共同作为参考依据,对学生网络模型进行修正,重复执行上述过程,直至学生网络模型达到预设标准结束。因为对学生网络模型的训练是采用未标记数据,不涉及已标记数据,因此大大降低了人力标记的成本,而且同样可以达到与利用已标记数据训练学生网络模型相同或类似的效果。
可选的,在上述实施例的基础上,本发明实施例提供的另一种模型蒸馏方法,与上述实施例相同或相应的内容这里不再过多介绍,下文中侧重介绍“对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果”的具体执行过程,具体参见图3 所示,该方法步骤包括:
步骤310,分别对第i个未标记数据执行不同类型的弱增广处理。
具体的,在图像的深度学习中,为了丰富图像训练集,提高模型的泛化能力,一般会对图像进行数据增强。也即是常用的增广方法,其中,弱增广的方式例如:旋转、改变图像色差、扭曲图像特征、增加图像噪声(高斯噪声、盐胶噪声)等等。
步骤320,将经过不同类型的弱增广处理后的未标记数据,对应输入到不同的教师网络模型中,获取一组第一识别结果。
通过不同的弱增广方式,可以对第i个未标记数据执行不同的处理,自然经过不同处理后的数据(图像)输入到不同的教师网络模型中,也可以得到不同的预测结果。
进行弱增广处理的目的,如上所介绍的,是为了增加未标记数据的泛化能力,以增强模型的识别能力。同样的数据经过不同的增广方式,不会改变它的类型,但却可以增强模型对输入数据发生变化有更强的适应能力。因此,采用不同的弱增广方式对未标记数据进行处理。
在一个可选的例子中,模型蒸馏方法,主要目的就是将多个模型压缩成一个模型。那么,多个教师模型的输出结果,也需要进行集成,得到多个教师网络模型集成后的输出结果,也即是第二识别结果。因此,下文中,将详细说明根据一组第一识别结果,获取第二识别结果的详细过程。在介绍该方法步骤之前,首先介绍第一识别结果所包括的内容。具体的,第一识别结果可以包括:与第i个未标记数据对应的至少一个备选标签,以及每一个备选标签对应的第一候选概率。
而根据一组第一识别结果,获取第二识别结果,具体可以包括:
分别获取一组第一识别结果中相同备选标签对应的第一候选概率的第一概率平均值。
将第一识别结果中的至少一个备选标签,以及每一个备选标签对应的第一概率平均值构成第二识别结果。
在一个可选的例子中,假设在一个教师网络模型中,与第i个未标记数据对应的备选标签包括6个,相应的,每个备选标签对应的候选概率同样包括6个,且6个候选概率的总和为1。
假设存在5个教师网络模型,那么就对应5个第一识别结果,每一个第一识别结果都包括两种数据,其中一种数据是6个备选标签,另一种数据是6个候选概率。
将5个教师网络模型中所输出的属于同一个备选标签的5个候选概率数据相叠加并求取平均值,即可获取一个第一概率平均值。因为有6个备选标签,自然也对应6个第一概率平均值。
而这6个第一概率平均值,以及6个备选标签,则构成了第二识别结果,每一个第一概率平均值,均作为与之对应的备选标签的预测概率。
在介绍与教师网络模型相关操作后,下面则介绍与学生网络相关的操作,具体参见如下实施例。
在上述任一实施例的基础上,本发明实施例还提供了另一种模型蒸馏方法,与上述任一实施例相同或类似的地方这里不再过多赘述。这里将详细说明对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果的具体执行过程,具体参见图4所示,该方法步骤包括:
步骤410,分别对第i个未标记数据进行强增广处理。
步骤420,将经过强增广处理后的未标记数据,输入到学生网络模型中,获取第三识别结果。
如上文对第i个未标记数据进行弱增强处理相类似的,这里执行强增广处理。
强增广处理例如包括剪切、改变图像尺寸等等。执行强增广处理,是进一步提升第i个未标记数据(图像)的泛化能力,以进一步提升模型的识别能力,使得学生网络模型能够接近多个教师网络模型集成的效果。
在另一个可选的实施例中,第三识别结果中包括:第i个未标记数据的至少一个备选标签中每一个备选标签对应的第二候选概率,与第i个未标记数据对应的目标标签。
根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,具体包括如下方法步骤,具体参见图5所示。
步骤510,根据第二识别结果中每一个备选标签对应的第一概率平均值,以及第三识别结果中每一个备选标签对应的第二候选概率,确定第一损失函数。
步骤520,根据伪标签以及目标标签,确定第二损失函数。
步骤530,根据第一损失函数和第二损失函数,对学生网络模型进行优化。
具体的,将每一个备选标签对应的第一概率平均值,以及第三识别结果中每一个备选标签对应的第二候选概率共同输入到预配置的计算损失的公式中,得到第一损失函数H(f,g)。类似的道理,可以将目标标签以及伪标签共同输入到配配置的另一个可以计算损失的公式中,得到第二损失函数H(p,q)。然后,根据第一损失函数和第二损失函数,对学生网络模型进行优化。
在一个可执行的具体例子中,优化操作可以包括根据第一损失函数和第二损失函数,调整学生网络模型中的参数。
可选的,确定学生网络模型达到某种预设标准可以参见如下:
根据第一损失函数和第二损失函数,确定目标损失;
当目标损失达到预设损失阈值时,确定学生网络模型符合预设标准。
具体的,目标损失为H(f,g)+γH(p,q)之和,当H(f,g)+ γH(p,q)之和趋近于最小化时,可以确定学生网络模型符合预设标准。
例如设定一个界限值,当H(f,g)+γH(p,q)等于或者小于界限值时,停止对网络模型优化。每次根据H(f,g)+γH(p,q)的结果不断调整学生网络模型中的参数。调整后利用下一个未标记数据,重复执行上文中的所有操作,然后再次计算H(f,g)+γH(p,q),直到学生网络模型达到预设标准,也即是H(f,g)+γH(p,q)等于或者小于界限值时结束。
图6为本发明提供的一种模型蒸馏方法的一个具体应用实例的方法流程示意图,具体参见图6所示,未标记数据分别执行弱增广处理后,输入到不同的教师网络模型模型中,每一个教师网络模型得到一个模型预测结果,也即是第一识别结果。然后,对所有的第一识别结果进行集成,得到第二识别结果,也即是图6中所显示的模型预测集成,并且获取模型预测类别(也即是伪标签)。未标记数据执行强增广后,输入到学生网络模型中,得到第三识别结果,包括每一个备选标签对应的第二候选概率,以及每一个备选标签,同时还获取到目标标签。
然后根据所有第二候选概率以及所有第一概率平均值得到第一损失函数,根据伪标签和目标标签,得到第二损失函数。最终,根据第一损失函数和第二损失函数,对学生网络模型进行优化(此步骤图中未显示)。具体的执行过程参见上文,这里不再过多赘述。
以上,为本申请所提供的多模型蒸馏几个方法实施例,下文中则介绍说明本申请所提供的多模型蒸馏其他实施例,具体参见如下。
图7为本发明实施例提供的一种多模型蒸馏装置,该装置包括:处理模块701、设定模块702,以及优化模块703。
处理模块701,用于对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;根据一组第一识别结果,获取第二识别结果,其中第二识别结果包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率;
设定模块702,用于当至少一个备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将第一备选标签设定为未标记数据的伪标签,其中,第一备选标签为至少一个备选标签中概率最大的标签;
处理模块701,还用于对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;
优化模块703,根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束,其中,i 为正整数。
可选的,处理模块701,具体用于:
分别对第i个未标记数据执行不同类型的弱增广处理;
将经过不同类型的弱增广处理后的未标记数据,对应输入到不同的教师网络模型中,获取一组第一识别结果。
可选的,第一识别结果包括与第i个未标记数据对应的至少一个备选标签,以及每一个备选标签对应的第一候选概率;
处理模块701,具体用于:分别获取一组第一识别结果中相同备选标签对应的第一候选概率的第一概率平均值;
第一识别结果中的至少一个备选标签,以及每一个备选标签对应的第一概率平均值构成第二识别结果,其中每一个备选标签对应的第一概率平均值,即为第二识别结果中与备选标签对应的预测概率。
可选的,处理模块701,具体用于:分别对第i个未标记数据进行强增广处理;
将经过强增广处理后的未标记数据,输入到学生网络模型中,获取第三识别结果。
可选的,第三识别结果中包括:第i个未标记数据的至少一个备选标签中每一个备选标签对应的第二候选概率,与第i个未标记数据对应的目标标签;
优化模块703,具体用于:根据第二识别结果中每一个备选标签对应的第一概率平均值,以及第三识别结果中每一个备选标签对应的第二候选概率,确定第一损失函数;
根据伪标签以及目标标签,确定第二损失函数;
根据第一损失函数和第二损失函数,对学生网络模型进行优化。
本发明实施例提供的一种多模型蒸馏装置中各部件所执行的功能均已在上述任一方法实施例中做了详细的描述,因此这里不再赘述。
本发明实施例提供的一种多模型蒸馏装置,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,其中教师网络模型为已经训练好的教师网络模型。得到一组第一识别结果,然后根据一组第一识别结果,获取第二识别结果。在第二识别结果中,包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率。当至少一个备选标签中的第一备选标签的预测概率大于或者等于预设概率阈值时,则可以认定第一备选标签为与第i个未标记数据对应的伪标签。然后,对第i个未标记数据进行第二泛化处理后,输入到学生网络模型,获取第三识别结果,将第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束。该过程中,考虑到同一个未标记数据对应的标签具有唯一性,且因为教师网络模型为训练好的网络模型,其最终得出的训练结果具有一定的有效性。因此,可以利用第二识别结果和伪标签,以及第三识别结果共同作为参考依据,对学生网络模型进行修正,重复执行上述过程,直至学生网络模型达到预设标准结束。因为对学生网络模型的训练是采用未标记数据,不涉及已标记数据,因此大大降低了人力标记的成本,而且同样可以达到与利用已标记数据训练学生网络模型相同或类似的效果。
如图8所示,本申请实施例提供了一种电子设备,包括处理器111、通信接口112、存储器113和通信总线114,其中,处理器111,通信接口 112,存储器113通过通信总线114完成相互间的通信。
存储器113,用于存放计算机程序;
在本申请一个实施例中,处理器111,用于执行存储器113上所存放的程序时,实现前述任意一个方法实施例提供的模型蒸馏方法,包括:
对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;
根据一组第一识别结果,获取第二识别结果,其中第二识别结果包括第i个未标记数据的至少一个备选标签,以及每一个备选标签对应的预测概率;
当至少一个备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将第一备选标签设定为未标记数据的伪标签,其中,第一备选标签为至少一个备选标签中概率最大的标签;
对第i个未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;
根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,直至学生网络模型符合预设标准时结束,其中,i为正整数。
可选的,分别对第i个未标记数据执行不同类型的弱增广处理;
将经过不同类型的弱增广处理后的未标记数据,对应输入到不同的教师网络模型中,获取一组第一识别结果。
可选的,第一识别结果包括与第i个未标记数据对应的至少一个备选标签,以及每一个备选标签对应的第一候选概率;
根据一组第一识别结果,获取第二识别结果,具体包括:
分别获取一组第一识别结果中相同备选标签对应的第一候选概率的第一概率平均值;
第一识别结果中的至少一个备选标签,以及每一个备选标签对应的第一概率平均值构成第二识别结果,其中每一个备选标签对应的第一概率平均值,即为第二识别结果中与备选标签对应的预测概率。
可选的,分别对第i个未标记数据进行强增广处理;
将经过强增广处理后的未标记数据,输入到学生网络模型中,获取第三识别结果。
可选的,第三识别结果中包括:第i个未标记数据的至少一个备选标签中每一个备选标签对应的第二候选概率,与第i个未标记数据对应的目标标签;
根据第二识别结果、第三识别结果,以及伪标签,对学生网络模型进行优化,具体包括:
根据第二识别结果中每一个备选标签对应的第一概率平均值,以及第三识别结果中每一个备选标签对应的第二候选概率,确定第一损失函数;
根据伪标签以及目标标签,确定第二损失函数;
根据第一损失函数和第二损失函数,对学生网络模型进行优化。
可选的,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果之前,方法还包括:
利用已标记数据分别对至少两个教师网络模型中的每一个教师网络模型进行训练。
可选的,根据第一损失函数和第二损失函数,确定目标损失;
当目标损失达到预设损失阈值时,确定学生网络模型符合预设标准。
本申请实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现如前述任意一个方法实施例提供的模型蒸馏方法的步骤。
需要说明的是,在本文中,诸如“第一”和“第二”等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上仅是本发明的具体实施方式,使本领域技术人员能够理解或实现本发明。对这些实施例的多种修改对本领域的技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所申请的原理和新颖特点相一致的最宽的范围。
Claims (10)
1.一种模型蒸馏方法,其特征在于,所述方法包括:
对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;
根据一组所述第一识别结果,获取第二识别结果,其中第二识别结果包括第i个所述未标记数据的至少一个备选标签,以及每一个所述备选标签对应的预测概率;
当至少一个所述备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将所述第一备选标签设定为所述未标记数据的伪标签,其中,所述第一备选标签为至少一个所述备选标签中概率最大的标签;
对第i个所述未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;
根据所述第二识别结果、所述第三识别结果,以及所述伪标签,对所述学生网络模型进行优化,直至所述学生网络模型符合预设标准时结束,其中,i为正整数。
2.根据权利要求1所述的方法,其特征在于,对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果,具体包括:
分别对第i个所述未标记数据执行不同类型的弱增广处理;
将经过不同类型的弱增广处理后的未标记数据,对应输入到不同的教师网络模型中,获取一组所述第一识别结果。
3.根据权利要求1或2所述的方法,其特征在于,所述第一识别结果包括与第i个所述未标记数据对应的至少一个备选标签,以及每一个所述备选标签对应的第一候选概率;
所述根据一组所述第一识别结果,获取第二识别结果,具体包括:
分别获取一组所述第一识别结果中相同备选标签对应的第一候选概率的第一概率平均值;
所述第一识别结果中的至少一个备选标签,以及每一个备选标签对应的第一概率平均值构成所述第二识别结果,其中每一个备选标签对应的第一概率平均值,即为所述第二识别结果中与所述备选标签对应的预测概率。
4.根据权利要求1所述的方法,其特征在于,所述对第i个所述未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果,具体包括:
分别对第i个所述未标记数据进行强增广处理;
将经过强增广处理后的未标记数据,输入到所述学生网络模型中,获取所述第三识别结果。
5.根据权利要求1、2或4任一项所述的方法,其特征在于,所述第三识别结果中包括:第i个所述未标记数据的至少一个备选标签中每一个备选标签对应的第二候选概率,与第i个所述未标记数据对应的目标标签;
所述根据所述第二识别结果、所述第三识别结果,以及所述伪标签,对所述学生网络模型进行优化,具体包括:
根据所述第二识别结果中每一个备选标签对应的第一概率平均值,以及所述第三识别结果中每一个备选标签对应的第二候选概率,确定第一损失函数;
根据所述伪标签以及所述目标标签,确定第二损失函数;
根据所述第一损失函数和所述第二损失函数,对所述学生网络模型进行优化。
6.一种多模型蒸馏装置,其特征在于,所述装置包括:
处理模块,用于对第i个未标记数据进行第一泛化处理后,分别输入到至少两个教师网络模型中,获取一组第一识别结果;根据一组所述第一识别结果,获取第二识别结果,其中第二识别结果包括第i个所述未标记数据的至少一个备选标签,以及每一个所述备选标签对应的预测概率;
设定模块,用于当至少一个所述备选标签中第一备选标签的预测概率大于或者等于预设概率阈值时,将所述第一备选标签设定为所述未标记数据的伪标签,其中,所述第一备选标签为至少一个所述备选标签中概率最大的标签;
所述处理模块,还用于对第i个所述未标记数据进行第二泛化处理后,输入到学生网络模型中,获取第三识别结果;
优化模块,根据所述第二识别结果、所述第三识别结果,以及所述伪标签,对所述学生网络模型进行优化,直至所述学生网络模型符合预设标准时结束,其中,i为正整数。
7.根据权利要求6所述的装置,其特征在于,所述处理模块,具体用于:
分别对第i个所述未标记数据执行不同类型的弱增广处理;
将经过不同类型的弱增广处理后的未标记数据,对应输入到不同的教师网络模型中,获取一组所述第一识别结果。
8.根据权利要求6或7所述的装置,其特征在于,所述第一识别结果包括与第i个所述未标记数据对应的至少一个备选标签,以及每一个所述备选标签对应的第一候选概率;
所述处理模块,具体用于:分别获取一组所述第一识别结果中相同备选标签对应的第一候选概率的第一概率平均值;
所述第一识别结果中的至少一个备选标签,以及每一个备选标签对应的第一概率平均值构成所述第二识别结果,其中每一个备选标签对应的第一概率平均值,即为所述第二识别结果中与所述备选标签对应的预测概率。
9.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1-5任一项所述的模型蒸馏方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1-5任一项所述的模型蒸馏方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111680113.7A CN114519416A (zh) | 2021-12-30 | 2021-12-30 | 模型蒸馏方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111680113.7A CN114519416A (zh) | 2021-12-30 | 2021-12-30 | 模型蒸馏方法、装置及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114519416A true CN114519416A (zh) | 2022-05-20 |
Family
ID=81597722
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111680113.7A Pending CN114519416A (zh) | 2021-12-30 | 2021-12-30 | 模型蒸馏方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114519416A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115099988A (zh) * | 2022-06-28 | 2022-09-23 | 腾讯科技(深圳)有限公司 | 模型训练方法、数据处理方法、设备及计算机介质 |
-
2021
- 2021-12-30 CN CN202111680113.7A patent/CN114519416A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115099988A (zh) * | 2022-06-28 | 2022-09-23 | 腾讯科技(深圳)有限公司 | 模型训练方法、数据处理方法、设备及计算机介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
EP2806374B1 (en) | Method and system for automatic selection of one or more image processing algorithm | |
CN110490180B (zh) | 基于图像识别的作业批改方法、装置、存储介质及服务器 | |
CN110378235B (zh) | 一种模糊人脸图像识别方法、装置及终端设备 | |
CN109960725B (zh) | 基于情感的文本分类处理方法、装置和计算机设备 | |
JP7266674B2 (ja) | 画像分類モデルの訓練方法、画像処理方法及び装置 | |
EP3989104A1 (en) | Facial feature extraction model training method and apparatus, facial feature extraction method and apparatus, device, and storage medium | |
CN111598182B (zh) | 训练神经网络及图像识别的方法、装置、设备及介质 | |
CN113076994B (zh) | 一种开集域自适应图像分类方法及系统 | |
CN107516102B (zh) | 图像数据分类与建立分类模型方法、装置及系统 | |
CN110909224B (zh) | 一种基于人工智能的敏感数据自动分类识别方法及系统 | |
CN113010683B (zh) | 基于改进图注意力网络的实体关系识别方法及系统 | |
CN103824090A (zh) | 一种自适应的人脸低层特征选择方法及人脸属性识别方法 | |
CN111291773A (zh) | 特征识别的方法及装置 | |
CN110610230A (zh) | 一种台标检测方法、装置及可读存储介质 | |
US20220358658A1 (en) | Semi Supervised Training from Coarse Labels of Image Segmentation | |
CN110796210A (zh) | 一种标签信息的识别方法及装置 | |
CN113723070A (zh) | 文本相似度模型训练方法、文本相似度检测方法及装置 | |
CN110942063B (zh) | 证件文字信息获取方法、装置以及电子设备 | |
CN114519416A (zh) | 模型蒸馏方法、装置及电子设备 | |
CN110795997A (zh) | 基于长短期记忆的教学方法、装置和计算机设备 | |
CN113283388A (zh) | 活体人脸检测模型的训练方法、装置、设备及存储介质 | |
CN117173677A (zh) | 手势识别方法、装置、设备及存储介质 | |
CN116563604A (zh) | 端到端目标检测模型训练、图像目标检测方法及相关设备 | |
WO2023173546A1 (zh) | 文本识别模型的训练方法、装置、计算机设备及存储介质 | |
CN114241253A (zh) | 违规内容识别的模型训练方法、系统、服务器及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |