CN117237742B - 一种针对初始模型的知识蒸馏方法和装置 - Google Patents

一种针对初始模型的知识蒸馏方法和装置 Download PDF

Info

Publication number
CN117237742B
CN117237742B CN202311481966.7A CN202311481966A CN117237742B CN 117237742 B CN117237742 B CN 117237742B CN 202311481966 A CN202311481966 A CN 202311481966A CN 117237742 B CN117237742 B CN 117237742B
Authority
CN
China
Prior art keywords
model
pooling layer
average pooling
global average
initial
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202311481966.7A
Other languages
English (en)
Other versions
CN117237742A (zh
Inventor
沈艳梅
宿栋栋
刘伟
王彦伟
李仁刚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Suzhou Metabrain Intelligent Technology Co Ltd
Original Assignee
Suzhou Metabrain Intelligent Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Suzhou Metabrain Intelligent Technology Co Ltd filed Critical Suzhou Metabrain Intelligent Technology Co Ltd
Priority to CN202311481966.7A priority Critical patent/CN117237742B/zh
Publication of CN117237742A publication Critical patent/CN117237742A/zh
Application granted granted Critical
Publication of CN117237742B publication Critical patent/CN117237742B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Landscapes

  • Image Analysis (AREA)

Abstract

本发明实施例提供了一种针对初始模型的知识蒸馏方法和装置,涉及模型知识蒸馏技术领域,通过本发明实施例,可以通过生成目标图像;采用初始模型基于目标图像生成初始学生模型,并将初始学生模型确定为第一教师模型;基于目标图像,获取初始模型预测值和第一GAP模型预测值;通过初始模型预测值和第一GAP模型预测值,计算第一分类正确率;基于第一分类正确率消除标签噪声,生成第一平滑软标签;采用第一教师模型基于第一平滑软标签,生成第一目标学生模型,从而实现对软标签进行加权平均来平滑软标签噪声,将平滑后软标签用于学生模型的蒸馏分类损失计算,不仅可以获取模型性能的提升,还能提高模型的鲁棒性。

Description

一种针对初始模型的知识蒸馏方法和装置
技术领域
本发明涉及模型知识蒸馏技术领域,特别是涉及一种针对初始模型的知识蒸馏方法、一种针对初始模型的知识蒸馏装置、一种电子设备以及一种计算机可读存储介质。
背景技术
知识蒸馏指的是将预训练好的教师模型的知识通过蒸馏的方式迁移至学生模型,通常教师模型会比学生模型网络容量更大、模型结构更复杂,学生模型通过学习教师模型的更可信的软标签来获取提升。
在实际应用中,仅通过人工标注标签无法反应真实的标签分布情况,导致训练效率低下。
发明内容
本发明实施例是提供一种针对初始模型的知识蒸馏方法、装置、电子设备以及计算机可读存储介质,以解决如何提升针对模型的知识蒸馏训练效率的问题。
本发明实施例公开了针对初始模型的知识蒸馏方法,所述初始模型包括卷积层,所述卷积层配置有对应的全局平均池化层模型,包括:
生成目标图像;
采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;
基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;
通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;
基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;
采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型。
可选地,还包括:
将所述第一目标学生模型确定为第三迭代周期的第二教师模型;
获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值;
通过所述教师模型预测值和所述第二全局平均池化层模型预测值,计算针对所述教师模型和所述全局平均池化层模型的第二分类正确率;
基于所述第二分类正确率消除标签噪声,生成第二平滑软标签;
采用所述教师模型基于所述第二平滑软标签,在所述第三迭代周期生成第二目标学生模型。
可选地,还包括:
将最后一次迭代生成的学生模型确定为分类模型;
将所述目标图像输入所述分类模型,输出模型预测概率,并将最大的模型预测概率对应的结果确定为分类结果。
可选地,所述基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值的步骤包括:
将所述目标图像作为输入图像,确定针对所述初始模型输出的初始模型预测概率向量,以及与针对所述初始模型预测概率向量对应的第一向量维度;
采用针对所述初始模型预测概率向量和所述第一向量维度,计算针对所述全局平均池化层模型的第一全局平均池化层模型预测概率向量;
通过所述初始模型预测概率向量确定针对所述初始模型的初始模型预测值,并通过所述第一全局平均池化层模型预测概率向量确定针对所述全局平均池化层模型的第一全局平均池化层模型预测值。
可选地,在所述基于所述第一分类正确率消除标签噪声,生成第一平滑软标签的步骤之前,还包括:
获取人工标注模型标签值;
计算所述初始模型预测值和所述人工标注模型标签值之间的第一交叉熵损失,并采用所述第一交叉熵损失确定针对所述初始模型的第一最小分类损失值;
计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失,并采用所述第二交叉熵损失确定针对所述全局平均池化层模型的第二最小分类损失值。
可选地,还包括:
确定针对所述初始模型的初始模型权重;
采用所述第一最小分类损失值确定针对所述初始模型的初始模型偏置梯度,基于所述初始模型权重和所述初始模型偏置梯度更新模型参数;
确定针对所述全局平均池化层模型的第一全局平均池化层模型权重;
采用所述第二最小分类损失值确定针对所述全局平均池化层模型的第一全局平均池化层模型偏置梯度,基于所述第一全局平均池化层模型权重和所述第一全局平均池化层模型偏置梯度更新模型参数。
可选地,所述获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值的步骤包括:
将所述目标图像作为输入图像,确定针对所述第二教师模型输出的第二教师模型预测概率向量,以及与针对所述第二教师模型预测概率向量对应的第二向量维度;
采用针对所述第二教师模型预测概率向量和所述第二向量维度,计算针对所述全局平均池化层模型的第二全局平均池化层模型预测概率向量;
通过所述第二教师模型预测概率向量确定获取针对所述第二教师模型的教师模型预测值,并通过所述第二全局平均池化层模型预测概率向量确定针对所述全局平均池化层模型的第二全局平均池化层模型预测值。
可选地,在所述基于所述第二分类正确率消除标签噪声,生成第二平滑软标签的步骤之前,还包括:
计算所述第一平滑软标签和所述教师模型预测值之间的第一库尔巴克莱布勒散度,并采用所述第一库尔巴克莱布勒散度确定针对所述第二教师模型的第三最小分类损失值;
计算所述第一平滑软标签和第二全局平均池化层模型预测值之间的第二库尔巴克莱布勒散度,并采用所述第二库尔巴克莱布勒散度确定针对所述全局平均池化层模型的第四最小分类损失值。
可选地,所述采用所述第三交叉熵损失和所述第四交叉熵损失确定第二分类正确率的步骤包括:
确定针对所述第二教师模型的第二教师模型权重;
采用所述第三最小分类损失值确定针对所述第二教师模型的第二教师模型偏置梯度,基于所述第二教师模型权重和所述第二教师模型偏置梯度更新模型参数;
确定针对所述全局平均池化层模型的第二全局平均池化层模型权重;
采用所述第四最小分类损失值确定针对所述全局平均池化层模型的第二全局平均池化层模型偏置梯度,基于所述第二全局平均池化层模型权重和所述第二全局平均池化层模型偏置梯度更新模型参数。
可选地,所述生成目标图像的步骤包括:
获取初始图像集;
确定目标亮度和目标尺寸;
基于所述目标亮度和所述目标尺寸,对所述初始图像集进行归一化操作,生成目标图像。
可选地,所述归一化操作包括均值归一化。
可选地,所述归一化操作包括方差归一化。
可选地,阈值归一化。
可选地,还包括:
基于缺失值对所述初始图像集执行数据清洗操作。
可选地,还包括:
基于异常值对所述初始图像集执行数据清洗操作。
可选地,还包括:
基于噪声数据对所述初始图像集执行数据清洗操作。
可选地,还包括:
对所述初始图像集执行数据抽样操作;所述数据抽样操作包括:随机抽样,和/或,分层抽样,和/或,过采样和欠采样。
本发明实施例还公开了一种针对初始模型的知识蒸馏装置,所述初始模型包括卷积层,所述卷积层配置有对应的全局平均池化层模型,包括:
目标图像生成模块,用于生成目标图像;
第一教师模型确定模块,用于采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;
预测值获取模块,用于基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;
第一分类正确率计算模块,用于通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;
第一平滑软标签生成模块,用于基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;
第一目标学生模型生成模块,用于采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型。
本发明实施例还公开了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,所述处理器、所述通信接口以及所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行存储器上所存放的程序时,实现如本发明实施例所述的方法。
本发明实施例还公开了一种计算机可读存储介质,其上存储有指令,当由一个或多个处理器执行时,使得所述处理器执行如本发明实施例所述的方法。
本发明实施例包括以下优点:
本发明实施例,可以通过生成目标图像;采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型,在每个卷积层后增加GAP和全连接层以获取更泛化且多样的软标签,对软标签进行加权平均来平滑软标签噪声,将平滑后软标签用于学生模型的蒸馏分类损失计算,不仅可以获取模型性能的提升,还能提高模型的鲁棒性。
附图说明
图1是本发明实施例中提供的一种针对初始模型的知识蒸馏方法的步骤流程图;
图2是本发明实施例中提供的一种模型迭代流程示意图;
图3是本发明实施例中提供的一种针对GAP模型的结构示意图;
图4是本发明实施例中提供的一种针对初始模型的知识蒸馏装置的结构框图;
图5是本发明实施例中提供的一种电子设备的硬件结构框图;
图6是本发明实施例中提供的一种计算机可读介质的示意图。
具体实施方式
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本发明作进一步详细的说明。
深度神经网络在各类应用场景中表现出显著的性能,然而强大的性能的也伴随着模型计算和参数量的爆炸式增长,不仅提高了模型部署的成本,而且导致模型过度拟合、泛化性差、学习效率低等问题。对此,一些模型剪枝、轻量化模型设计、知识蒸馏等方法被提出用于解决这个问题,其中,知识蒸馏是较为高效的方法之一。
知识蒸馏指的是将预训练好的教师模型的知识通过蒸馏的方式迁移至学生模型,通常教师模型会比学生模型网络容量更大、模型结构更复杂,学生模型通过学习教师模型的更可信的软标签来获取提升。知识蒸馏的架构包括在线、离线、自蒸馏等方式,蒸馏的知识类型包括输出响应概率、特征映射图、层间关联等。其中,自蒸馏是教师和学生模型使用相同的网络,属于在线蒸馏的一种特例,不仅能降低训练开销,还能实现更高的精度。
在相关技术采用的自蒸馏方法中,自蒸馏方案(Knowledge Distillation andLabel Smooth-ing Regularization和Regularizing Class-wise Predictions viaSelf-knowledge Distillation),采用标签平滑正则方法进行蒸馏,前者采用手动设计软标签以获取预期的标签正则化分布,后者结合相同类别的不同标签的分布一致性;手动设计教师模型的软标签无法反映真实标签分布情况;教师模型软标签包含噪声影响学生模型的性能。
多阶段蒸馏方法使用前面几次迭代训练的模型作为教师模型,蒸馏后面几次迭代训练的学生模型,具有训练开销小、精度高的特点。但是,由于教师模型的软标签存在不确定噪声,导致训练模型鲁棒性差、训练效率低问题。
本发明提出的基于标签平滑的图像分类模型自蒸馏方法,采用多阶段蒸馏的方式,通过修改模型网络结构获取更多样化的软标签,通过对软标签进行加权平均来获取较平滑的软标签,用于学生模型的蒸馏分类损失训练。本发明的方法不仅可以平滑软标签噪声,还能提高模型的泛化能力。
参照图1,示出了本发明实施例中提供的一种针对初始模型的知识蒸馏方法的步骤流程图,具体可以包括如下步骤:
步骤101,生成目标图像;
步骤102,采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;
步骤103,基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;
步骤104,通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;
步骤105,基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;
步骤106,采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型。
在具体实现中,本发明实施例可以获取初始图像集,并对初始图像集进行预处理。
示例性地,预处理方法可以包括亮度规范化、滤除噪声、尺度归一化、裁剪等。
本发明实施例还可以构建自蒸馏模型,构建知识蒸馏模型主要包括:选择教师模型和学生模型、选择蒸馏知识类型、确定蒸馏策略、配置蒸馏损失函数和修改初始模型的层次结构等环节。
其中,蒸馏策略可以采用多阶段自蒸馏的框架,参考图2,图2是本发明实施例中提供的一种模型迭代流程示意图,蒸馏策略是使用前一个迭代中的模型作为教师模型来蒸馏后一个迭代中的学生模型,那么前一次迭代中的学生模型也是后一次迭代中的教师模型。
针对如何选选择教师模型和学生模型,本发明实施例可以采用初始模型在第一迭代周期基于目标图像生成初始学生模型,该初始学生模型可以是为经人工指导,经由初始模型自行训练生成的模型,在生成初始学生模型后,可以将初始学生模型确定为第二迭代周期的第一教师模型。可以理解,第一教师模型可以在第二迭代中期中用于训练生成第一目标学生模型,而第一目标学生模型则可以在第三迭代周期中作为第三迭代周期的教师模型。
参考图3,图3是本发明实施例中提供的一种针对GAP模型的结构示意图。
针对如何修改初始模型的层次结构,本发明实施例可以在卷积层后面增加GAP(global average pooling)层和全连接层,作为全局平均池化层模型,即,GAP模型。
在具体实现中,仅通过人工标注标签无法反应真实的标签分布情况,并且,若不对训练标签进行降噪处理,可能会导致训练效率低下。
在具体实现中,本发明实施例可以基于目标图像,获取针对初始模型的初始模型预测值,以及,针对全局平均池化层模型的第一全局平均池化层模型预测值。通过初始模型预测值和第一全局平均池化层模型预测值,计算针对初始模型和全局平均池化层模型的第一分类正确率。
示例性地,可以将目标图像作为输入,使初始模型和GAP模型在第一迭代周期分别输出初始模型预测值和第一全局平均池化层模型预测值,具体地,可以通过前向传播计算获取初始模型的初始模型预测值以及每个卷积层的GAP第一全局平均池化层模型预测值,然后计算初始模型和GAP模型的分类损失,并通过反向传播更新初始模型和所有GAP模型的参数,完成一次全部数据集的迭代训练后,计算初始模型和每一层卷积层的GAP模型的分类正确率,将其作为第一分离正确率。
在实际应用中,经过降噪处理的标签可以被称为平滑软标签,在本发明实施例中,第一分离正确率,可以用于后续生成平滑软标签,本发明实施例可以基于第一分类正确率消除标签噪声,生成第一平滑软标签,第一平滑软标签可以是针对第二迭代周期的第一教师模型的训练标签。
示例性地,可以基于每个模型的分类正确率,分别对初始模型预测值、每一层卷积层的GAP模型的第一全局平均池化层模型预测值和由人工标记的人工标注模型标签值进行加权平均,所获取的软标签即为第一平滑软标签,第一平滑软标签可以用于下一次迭代(第二迭代周期)中,学生模型的分类损失计算。
在生成第一平滑软标签后,本发明实施例可以采用第一教师模型基于第一平滑软标签,在第二迭代周期生成第一目标学生模型。
本发明实施例,可以通过生成目标图像;采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型,在每个卷积层后增加GAP和全连接层以获取更泛化且多样的软标签,对软标签进行加权平均来平滑软标签噪声,将平滑后软标签用于学生模型的蒸馏分类损失计算,不仅可以获取模型性能的提升,还能提高模型的鲁棒性。
在上述实施例的基础上,提出了上述实施例的变型实施例,在此需要说明的是,为了使描述简要,在变型实施例中仅描述与上述实施例的不同之处。
在本发明的一个可选地实施例中,还包括:
将所述第一目标学生模型确定为第三迭代周期的第二教师模型;
获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值;
通过所述教师模型预测值和所述第二全局平均池化层模型预测值,计算针对所述教师模型和所述全局平均池化层模型的第二分类正确率;
基于所述第二分类正确率消除标签噪声,生成第二平滑软标签;
采用所述教师模型基于所述第二平滑软标签,在所述第三迭代周期生成第二目标学生模型。
在实际应用中,本发明实施例同样可以采用多阶段蒸馏的方式实现蒸馏训练以更进一步地提升蒸馏训练效率,在具体实现中,第三迭代周期可以是第二迭代周期后的迭代周期,在第三迭代周期中,可以将第二迭代周期中生成的第一目标学生模型确定为第二教师模型。
示例性地,为了在第三迭代周期中获得平滑软标签,可以使第二教师模型和全局平均池化层模型分别输出教师模型预测值和第二全局平均池化层模型预测值,具体地,可以通过前向传播计算获取针对第二教师模型的教师模型预测值和针对GAP模型的第二全局平均池化层模型预测值,然后计算第二教师模型和GAP模型的分类损失,并通过反向传播更新第二教师模型和所有GAP模型的参数,完成一次全部数据集的迭代训练后,计算第二教师模型和每一层卷积层的GAP模型的分类正确率,将其作为第二分离正确率,然后则可以基于第二分类正确率消除标签噪声,生成针对第二教师模型在第三迭代周期的第二平滑软标签,并采用教师模型基于第二平滑软标签,在第三迭代周期生成第二目标学生模型。
示例性地,平滑软标签可以通过如下方式生成,例如,迭代次数记为,第次迭代训练的教师模型中的初始模型记为,第层GAP模型记为,样本输入到初始模型的预测概率向量记为,输入到教师模型的层GAP模型的预测概率向量记为,样本的标记标签(硬标签)记为,教师模型的初始模型的分类正确率记为,第层卷积层对应的GAP模型的分类正确率记为。采用加权平均方法对教师模型的标签进行权重的加权平均得到平滑软标签。
其中,软标签平滑的权重的计算方法如下:
标签的权重:硬标签是人工标记的标签,默认权重为1.0,
标签的权重:初始模型所输出软标签的权重,取决于图像分类是否正确和模型分类正确率,若样本分类正确,软标签权重为模型的分类正确率,反之,权重为0,并且采用超参数来控制教师模型的标签对平滑软标签的影响,一般的取值配置取值越大,对于平滑软标签的影响越小。
标签的权重:教师模型的GAP模型所输出软标签的权重,取决于图像分类是否正确和GAP模型分类正确率,若样本分类正确,对应软标签权重为GAP模型的分类正确率,反之,权重为0,并采用超参数来控制教师模型的标签对平滑软标签的影响,一般的取值配置取值越大,对于平滑软标签的影响越小。
在获取第次迭代训练的教师模型的每个样本的平滑软标签之后,将其用于第次迭代的学生模型的分类损失函数计算中,计算软标签分类损失。
后续的迭代周期都可以按照第三迭代周期的方式生成学生模型,对此不再赘述。
可选地,本发明实施例可以将最后一次迭代生成的学生模型确定为分类模型;将所述目标图像输入所述分类模型,输出模型预测概率,并将最大的模型预测概率对应的结果确定为分类结果。
在具体实现中,本发明实施例在完成所有迭代次数的模型训练后,可以将最后一次迭代生成的学生模型确定为分类模型,并可以将图像输入到分类模型,经过前向传播计算,输出模型预测概率,概率值最大即为分类结果,从而实现对图形的分类。
在本发明的一个可选地实施例中,所述基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值的步骤包括:
将所述目标图像作为输入图像,确定针对所述初始模型输出的初始模型预测概率向量,以及与针对所述初始模型预测概率向量对应的第一向量维度;
采用针对所述初始模型预测概率向量和所述第一向量维度,计算针对所述全局平均池化层模型的第一全局平均池化层模型预测概率向量;
通过所述初始模型预测概率向量确定针对所述初始模型的初始模型预测值,并通过所述第一全局平均池化层模型预测概率向量确定针对所述全局平均池化层模型的第一全局平均池化层模型预测值。
在具体实现中,模型的预测概率向量可以用于后续计算分离损失。
示例性地,前向传播是从输入图像到输出模型预测值的过程,将输入图像记为,初始模型输出的初始模型预测概率向量记为,维度为中的每个值表示对应类别的初始模型预测值,表示初始模型从前到后的前向计算过程,。前向传播计算过程可看作卷积层的特征提取和全连接层的分类器模型的组合,将第层卷积层输出的特征映射图记为,第一向量维度记为,分别表示卷积层输出图像的宽、高和输出图像的数目。
GAP模型包括一个GAP层(全局平均层)和一个FC层(全连接层),其中,GAP层是对图像进行全局平均池化,即计算图像的平均值。将卷积层输出的特征映射图输入到对应的GAP模型中,将第层卷积层连接的GAP模型记为,GAP模型预测概率向量记为,那么,从而可以根据第一全局平均池化层模型预测概率向量确定出第一全局平均池化层模型预测值。
在本发明的一个可选地实施例中,在所述基于所述第一分类正确率消除标签噪声,生成第一平滑软标签的步骤之前,还包括:
获取人工标注模型标签值;
计算所述初始模型预测值和所述人工标注模型标签值之间的第一交叉熵损失,并采用所述第一交叉熵损失确定针对所述初始模型的第一最小分类损失值;
计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失,并采用所述第二交叉熵损失确定针对所述全局平均池化层模型的第二最小分类损失值。
本发明实施例可以获取人工标注模型标签值,计算初始模型预测值和人工标注模型标签值之间的第一交叉熵损失,并采用第一交叉熵损失确定针对初始模型的第一最小分类损失值;计算第一全局平均池化层模型预测值和人工标注模型标签值之间的第二交叉熵损失,并采用第二交叉熵损失确定针对全局平均池化层模型的第二最小分类损失值。
示例性地,不同的迭代次数对应的分类损失计算方法不一样,迭代次数记为为样本的类别的人工标记值(硬标签),表示第次迭代的初始模型对样本的类别的预测值,表示第次迭代的第层卷积层GAP模型的预测值,表示类别数目。
时,第一次迭代的模型没有教师模型,只能根据人工标记的硬标签学习,采用交叉熵损失函数计算初始模型预测值和硬标签(人工标注模型标签值)之间的交叉熵损失
同理,还可以计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失,即,计算GAP模型预测值和硬标签之间的交叉熵损失
在计算出的第一交叉熵损失和第二交叉熵损失后,可以基于第一交叉熵损失和第二交叉熵损失计算第一最小分类损失值和第二最小分类损失值。
示例性地,可以先对确定权重α。
其中,为第一最小分类损失值,为第二最小分类损失值。
在本发明的一个可选地实施例中,还包括:
确定针对所述初始模型的初始模型权重;
采用所述第一最小分类损失值确定针对所述初始模型的初始模型偏置梯度,基于所述初始模型权重和所述初始模型偏置梯度更新模型参数;
确定针对所述全局平均池化层模型的第一全局平均池化层模型权重;
采用所述第二最小分类损失值确定针对所述全局平均池化层模型的第一全局平均池化层模型偏置梯度,基于所述第一全局平均池化层模型权重和所述第一全局平均池化层模型偏置梯度更新模型参数。
在具体实现中,反向传播是模型训练过程中从输出预测概率值到输入图像的计算过程,根据图像的分类损失,即,第一最小分类损失值和第二最小分类损失,采用随机梯度下降等模型优化方法,计算初始模型和全局平均池化层模型权重和偏置的梯度并更新模型参数,以获取最小的分类损失。需要说明的是,初始模型的反向传播过程更新所有卷积层和全连接层的权重和偏置参数,GAP模型的反向传播过程只更新对应全连接层的权重和偏置参数。
假设蒸馏训练需要经过t次迭代,完成一次全部数据集的迭代训练后,计算初始模型和每一层卷积层GAP模型的分类正确率,将第次迭代后初始模型的分类正确率记为,将第次迭代后第层卷积层对应的GAP模型的分类正确率记为,用于下一步的标签平滑。
在本发明的一个可选地实施例中,所述获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值的步骤包括:
将所述目标图像作为输入图像,确定针对所述第二教师模型输出的第二教师模型预测概率向量,以及与针对所述第二教师模型预测概率向量对应的第二向量维度;
采用针对所述第二教师模型预测概率向量和所述第二向量维度,计算针对所述全局平均池化层模型的第二全局平均池化层模型预测概率向量;
通过所述第二教师模型预测概率向量确定获取针对所述第二教师模型的教师模型预测值,并通过所述第二全局平均池化层模型预测概率向量确定针对所述全局平均池化层模型的第二全局平均池化层模型预测值。
前向传播是从输入图像到输出模型预测值的过程,将输入图像记为,第二教师模型输出的第二教师模型预测概率向量记为,维度为中的每个值表示对应类别的分类概率,表示初始模型从前到后的前向计算过程,。前向传播计算过程可看作卷积层的特征提取和全连接层的分类器模型的组合,将第层卷积层输出的特征映射图记为,第二向量维度记为,分别表示卷积层输出图像的宽、高和输出图像的数目。
GAP模型包括一个GAP层(全局平均层)和一个FC层(全连接层),其中,GAP层是对图像进行全局平均池化,即计算图像的平均值。将卷积层输出的特征映射图输入到对应的GAP模型中,将第层卷积层连接的GAP模型记为,第二全局平均池化层模型预测概率向量记为,那么,在确定出第二全局平均池化层模型预测概率向量后,则可以通过第二全局平均池化层模型预测概率向量确定出第二全局平均池化层模型预测值。
在本发明的一个可选地实施例中,在所述基于所述第二分类正确率消除标签噪声,生成第二平滑软标签的步骤之前,还包括:
计算所述第一平滑软标签和所述教师模型预测值之间的第一库尔巴克莱布勒散度,并采用所述第一库尔巴克莱布勒散度确定针对所述第二教师模型的第三最小分类损失值;
计算所述第一平滑软标签和第二全局平均池化层模型预测值之间的第二库尔巴克莱布勒散度,并采用所述第二库尔巴克莱布勒散度确定针对所述全局平均池化层模型的第四最小分类损失值。
在具体实现中,除了第一次迭代以外,其他迭代模型都有教师模型,那么对应的图像分类损失除了硬标签分类损失外,还包含了教师模型的软标签分类损失,可以采用KL散度函数计算初始模型的输出预测值和教师模型的输出预测值之间的偏差。
示例性地,计算所述初始模型预测值和所述教师模型预测值之间的第三交叉熵损失可以通过如下方式实现。
迭代次数记为为样本的类别的人工标记值(硬标签),表示第次迭代的初始模型对样本的类别的预测值,表示第次迭代的第层卷积层GAP模型的预测值,表示类别数目。
时,第一次迭代的模型没有教师模型,只能根据人工标记的硬标签学习,采用交叉熵损失函数计算初始模型预测值和硬标签(人工标注模型标签值)之间的第一交叉熵损失
同理,还可以计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失,即,计算GAP模型预测值和硬标签之间的第二交叉熵损失
时,学生模型有教师模型,图像分类损失包括硬标签分类损失和教师模型的软标签分类损失,采用KL散度函数计算模型的输出预测值和教师模型的输出软标签值之间的偏差。教师模型为第次迭代的模型,将教师模型输出的平滑后软标签记为,确定的权重超参数,计算初始模型的预测值和平滑软标签之间的第一库尔巴克莱布勒散度和总的第三最小分类损失值
同理,可以计算初第二全局平均池化层模型预测值和教师模型预测值之间的第四交叉熵损失,即计算GAP模型预测值和教师模型的平滑软标签之间的第二库尔巴克莱布勒散度和第四最小分类损失值
图像分类类别为模型输出概率向量中的最大值对应下标,以初始模型为例,模型预测的样本的类别为的最大值对应下标,记为,真实类别为人工标记的标签向量中值为1.0对应下标,记为,如果,则分类正确。
在本发的一个可选地实施例中,还包括:
确定针对所述第二教师模型的第二教师模型权重;
采用所述第三最小分类损失值确定针对所述第二教师模型的第二教师模型偏置梯度,基于所述第二教师模型权重和所述第二教师模型偏置梯度更新模型参数;
确定针对所述全局平均池化层模型的第二全局平均池化层模型权重;
采用所述第四最小分类损失值确定针对所述全局平均池化层模型的第二全局平均池化层模型偏置梯度,基于所述第二全局平均池化层模型权重和所述第二全局平均池化层模型偏置梯度更新模型参数。
在具体实现中,反向传播的计算过程,是采用随机梯度下降等模型优化方法最小化图像分类损失,可以根据图像分类损失,即,第三最小分类损失值和第四最小分类损失值,计算第二教师模型和全局平均池化层模型权重和偏置的梯度并更新模型参数。第二教师模型的反向传播过程更新所有卷积层和全连接层的权重和偏置参数,GAP模型的反向传播过程只更新对应的全连接层的权重和偏置参数。
完成一次全部数据集的迭代训练后,计算第二教师模型和每一层卷积层GAP模型的分类正确率,将第次迭代后初始模型的分类正确率记为,将第次迭代后第层卷积层对应的GAP模型的分类正确率记为,用于下一步的标签平滑。
可选地,所述分类模型不包含所述全局平均池化层模型。
在具体实现中,完成所有次数的蒸馏训练后,可以去掉所有的GAP模型结构,保留初始模型作为分类模型,并将样本图片输入到经知识蒸馏的初始模型,输出预测概率向量,概率值最大对应的下标为分类类别。
可选地,所述生成目标图像的步骤包括:
获取初始图像集;
确定目标亮度和目标尺寸;
基于所述目标亮度和所述目标尺寸,对所述初始图像集进行归一化操作,生成目标图像。
在具体实现中,本发明实施例可以对获取到的初始图像集进行预处理,具体包括:确定目标亮度和目标尺寸,并基于目标亮度和所述目标尺寸,对初始图像集进行归一化操作,生成目标图像。
可选地,所述归一化操作包括均值归一化。
均值归一化:将目标亮度和目标尺寸减去均值,使得数据的均值为0。
可选地,所述归一化操作包括方差归一化。
方差归一化:将目标亮度和目标尺寸除以标准差,使得数据的方差为1。
可选地,阈值归一化。
阈值归一化:将目标亮度和目标尺寸缩放到0到1的范围内。
可选地,还包括:
基于缺失值对所述初始图像集执行数据清洗操作。
可选地,还包括:
基于异常值对所述初始图像集执行数据清洗操作。
可选地,还包括:
基于噪声数据对所述初始图像集执行数据清洗操作。
除了将数据进行归一化操作,数据清洗也是大规模训练数据清理的重要步骤之一。在数据采集过程中,可能会存在一些错误、噪声或者异常值,这些数据可能会对模型的训练造成干扰,降低模型的准确性。因此,需要对数据进行清洗,去除这些错误数据。
数据清洗可以通过以下几种方法来实现:
缺失值处理:对于存在缺失值的样本,可以选择删除或者填充缺失值。删除缺失值可能会导致数据量的减少,但可以避免对模型的干扰。而填充缺失值可以通过均值、中位数等方法进行。
异常值处理:对于存在异常值的样本,可以选择删除或者修正异常值。删除异常值可能会导致数据量的减少,但可以避免对模型的干扰。修正异常值可以通过替换为均值、中位数等方法进行。
噪声数据处理:对于存在噪声数据的样本,可以选择删除或者平滑噪声数据。删除噪声数据可能会导致数据量的减少,但可以避免对模型的干扰。平滑噪声数据可以通过滤波等方法进行。
可选地,对所述初始图像集执行数据抽样操作;所述数据抽样操作包括:随机抽样,和/或,分层抽样,和/或,过采样和欠采样。
在进行大规模模型的训练时,由于数据量庞大,可能会导致训练时间过长或者资源消耗过大。因此,可以通过数据抽样的方法,从大规模数据集中随机选择一部分样本进行训练,以减少训练时间和资源消耗。
针对数据的抽样操作可以通过以下几种方法来实现:
随机抽样:从数据集中随机选择一部分样本进行训练。
分层抽样:根据样本的类别进行分层抽样,保证每个类别的样本在抽样后的数据集中的比例与原始数据集中的比例一致。
过采样和欠采样:对于数据不平衡的情况,可以通过过采样和欠采样的方法来调整样本的比例,使得各个类别的样本数量平衡。
为使本领域技术人员更好地理解本发明实施例,以下采用一完整示例对本发明实施例进行说明。
Step1:获取图像,并对图像预处理;
图像数据通常来源于采集装置和公共数据集,预处理方法主要包括亮度规范化、尺度归一化、裁剪、翻转等。
Step2:构建蒸馏模型,修改模型层次结构,在卷积层后面增加GAP层和全连接层;
自蒸馏模型构建:
1)采用多阶段自蒸馏的框架,蒸馏策略是使用前一个迭代中的模型作为教师模型来蒸馏后一个迭代中的学生模型;
2)蒸馏知识选择模型的预测标签值,教师模型输出的模型预测值,经过标签平滑后,输入到学生模型的分类损失函数中;
3)修改模型网络结构,在初始模型的每一个卷积层后面增加一个GAP层(全局平均池化层)和FC层(全连接层),称其为GAP模型,GAP模型输出为分类预测值,GAP模型只训练FC层的参数;
4)学生模型的分类损失包括硬标签分类损失和教师模型软标签分类损失;
Step3:教师/学生模型的训练计算,通过前向传播计算获取初始模型的预测值、每个卷积层的GAP模型预测值,计算初始模型和GAP模型的分类损失,通过反向传播更新初始模型和所有GAP模型的参数,完成一次全部数据集的迭代训练后,计算初始模型和每一层卷积层GAP模型的分类正确率;
自蒸馏模型构建完成后开始蒸馏训练,训练过程包括:前向传播计算获取模型预测值、计算图像分类损失和反向传播计算更新模型参数,直到完成一次全部样本的迭代训练。
1)前向传播计算
输入图像为表示初始模型从前到后的前向计算,输出模型预测概率向量为,维度为中的每个值表示对应类别的分类概率,第层卷积层输出的特征映射图记为,维度记为,分别表示卷积层输出图像的宽、高和输出图像的数目。
GAP模型包括两层:GAP层和FC层,其中,GAP层是计算图像的平均值,将卷积层输出的特征映射图输入到对应的GAP模型中,将第层卷积层连接的GAP模型记为,模型预测概率向量记为,那么
2)计算分类损失和图像分类
不同的迭代次数对应的分类损失计算方法不一样,迭代次数记为为样本的类别的人工标记值(硬标签),表示第次迭代的初始模型对样本的类别的预测值,表示第次迭代的第层卷积层GAP模型的预测值,表示类别数目。
时,第一次迭代的模型没有教师模型,只能根据人工标记的硬标签学习,采用交叉熵损失函数计算初始模型预测值和硬标签之间的交叉熵损失
同理,计算GAP模型预测值和硬标签之间的交叉熵损失
时,学生模型有教师模型,图像分类损失包括硬标签分类损失和教师模型的软标签分类损失,采用KL散度函数计算模型的输出预测值和教师模型的输出软标签值之间的偏差。教师模型为第次迭代的模型,将教师模型输出的平滑后软标签记为表示的权重超参数,计算初始模型的预测值和平滑软标签之间的KL散度和总的损失
同理,计算GAP模型预测值和教师模型的平滑软标签之间的KL散度和总损失
图像分类类别为模型输出概率向量中的最大值对应下标,以初始模型为例,模型预测的样本的类别为的最大值对应下标,记为,真实类别为人工标记的标签向量中值为1.0对应下标,记为,如果,则分类正确。
3)反向传播计算
反向传播的计算过程,是采用随机梯度下降等模型优化方法最小化图像分类损失,计算模型权重和偏置的梯度并更新模型参数。初始模型的反向传播过程更新所有卷积层和全连接层的权重和偏置参数,GAP模型的反向传播过程只更新对应的全连接层的权重和偏置参数。
完成一次全部数据集的迭代训练后,计算初始模型和每一层卷积层GAP模型的分类正确率,将第次迭代后初始模型的分类正确率记为,将第次迭代后第层卷积层对应的GAP模型的分类正确率记为,用于下一步的标签平滑。
Step4:标签平滑,基于每个模型的分类正确率,分别对初始模型预测值、每一层卷积层的GAP模型预测值和标记标签值进行加权平均,所获取的软标签为平滑软标签,将其用于下一次迭代中学生模型的分类损失计算;
在完成一次全部数据集的迭代训练后,基于教师模型的每个模型的分类正确率,分别对教师模型的初始模型预测值、每一层卷积层的GAP模型预测值和标记标签值进行加权平均,获取平滑软标签,用于下一次迭代中学生模型的分类损失计算。
次迭代训练的教师模型中的初始模型记为,第层GAP模型记为,样本输入到教师模型的初始模型的预测概率向量记为,输入到教师模型的层GAP模型的预测概率向量记为,样本的标记标签(硬标签)记为,教师模型的初始模型的分类正确率记为,第层卷积层对应的GAP模型的分类正确率记为。采用加权平均方法对教师模型的标签进行权重的加权平均得到平滑软标签,
其中,平滑软标签的权重的计算方法如下:
标签的权重
硬标签是人工标记的标签,默认权重为1.0,
标签的权重
教师模型的初始模型所输出软标签的权重,取决于图像分类是否正确和模型分类正确率,采用超参数来控制教师模型的标签对平滑软标签的影响,配置
标签的权重
教师模型的GAP模型所输出软标签的权重,取决于图像分类是否正确和GAP模型分类正确率,超参数来控制教师模型的标签对平滑软标签的影响,
Step5:完成所有迭代次数的模型训练后,将图像输入到分类模型,经过前向传播计算,输出模型预测概率,概率值最大即为分类结果;
完成所有次数的蒸馏训练后,去掉所有的GAP模型结构,保留初始模型,将样本图片输入到初始模型,输出预测概率向量,概率值最大对应的下标为分类类别。
通过上述方式实现对初始模型进行知识蒸馏1)通过修改模型网络结构,增加GAP层来获取更多样化的软标签;2)在蒸馏训练中采用一种标签平滑方法处理教师模型输出的软标签,获取具有泛化性的软标签,将其用于学生模型的蒸馏分类损失训练。本发明的方法,修改模型结构不会产生较大训练开销,而且可以平滑软标签噪声,提高模型的泛化能力,进而提高模型的鲁棒性。
需要说明的是,对于方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本发明实施例并不受所描述的动作顺序的限制,因为依据本发明实施例,某些步骤可以采用其他顺序或者同时进行。其次,本领域技术人员也应该知悉,说明书中所描述的实施例均属于优选实施例,所涉及的动作并不一定是本发明实施例所必须的。
参照图4,示出了本发明实施例中提供的一种针对初始模型的知识蒸馏装置的结构框图,具体可以包括如下模块:
目标图像生成模块401,用于生成目标图像;
第一教师模型确定模块402,用于采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;
预测值获取模块403,用于基于所述目标图像,获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;
第一分类正确率计算模块404,用于通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;
第一平滑软标签生成模块405,用于基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;
第一目标学生模型生成模块406,用于采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型。
对于装置实施例而言,由于其与方法实施例基本相似,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
另外,本发明实施例还提供了一种电子设备,包括:处理器,存储器,存储在存储器上并可在处理器上运行的计算机程序,该计算机程序被处理器执行时实现上述针对初始模型的知识蒸馏方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
本发明实施例还提供了一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,计算机程序被处理器执行时实现上述针对初始模型的知识蒸馏方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。其中,所述的计算机可读存储介质,如只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random AccessMemory,简称RAM)、磁碟或者光盘等。
图5为实现本发明各个实施例的一种电子设备的硬件结构示意图。
该电子设备500包括但不限于:射频单元501、网络模块502、音频输出单元503、输入单元504、传感器505、显示单元506、用户输入单元507、接口单元508、存储器509、处理器510、以及电源511等部件。本领域技术人员可以理解,图5中示出的电子设备结构并不构成对电子设备的限定,电子设备可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。在本发明实施例中,电子设备包括但不限于手机、平板电脑、笔记本电脑、掌上电脑、车载终端、可穿戴设备、以及计步器等。
应理解的是,本发明实施例中,射频单元501可用于收发信息或通话过程中,信号的接收和发送,具体的,将来自基站的下行数据接收后,给处理器510处理;另外,将上行的数据发送给基站。通常,射频单元501包括但不限于天线、至少一个放大器、收发信机、耦合器、低噪声放大器、双工器等。此外,射频单元501还可以通过无线通信系统与网络和其他设备通信。
电子设备通过网络模块502为用户提供了无线的宽带互联网访问,如帮助用户收发电子邮件、浏览网页和访问流式媒体等。
音频输出单元503可以将射频单元501或网络模块502接收的或者在存储器509中存储的音频数据转换成音频信号并且输出为声音。而且,音频输出单元503还可以提供与电子设备500执行的特定功能相关的音频输出(例如,呼叫信号接收声音、消息接收声音等等)。音频输出单元503包括扬声器、蜂鸣器以及受话器等。
输入单元504用于接收音频或视频信号。输入单元504可以包括图形处理器(Graphics Processing Unit,GPU)5041和麦克风5042,图形处理器5041对在视频捕获模式或图像捕获模式中由图像捕获装置(如摄像头)获得的静态图片或视频的图像数据进行处理。处理后的图像帧可以显示在显示单元506上。经图形处理器5041处理后的图像帧可以存储在存储器509(或其它存储介质)中或者经由射频单元501或网络模块502进行发送。麦克风5042可以接收声音,并且能够将这样的声音处理为音频数据。处理后的音频数据可以在电话通话模式的情况下转换为可经由射频单元501发送到移动通信基站的格式输出。
电子设备500还包括至少一种传感器505,比如光传感器、运动传感器以及其他传感器。具体地,光传感器包括环境光传感器及接近传感器,其中,环境光传感器可根据环境光线的明暗来调节显示面板5061的亮度,接近传感器可在电子设备500移动到耳边时,关闭显示面板5061和/或背光。作为运动传感器的一种,加速计传感器可检测各个方向上(一般为三轴)加速度的大小,静止时可检测出重力的大小及方向,可用于识别电子设备姿态(比如横竖屏切换、相关游戏、磁力计姿态校准)、振动识别相关功能(比如计步器、敲击)等;传感器505还可以包括指纹传感器、压力传感器、虹膜传感器、分子传感器、陀螺仪、气压计、湿度计、温度计、红外线传感器等,在此不再赘述。
显示单元506用于显示由用户输入的信息或提供给用户的信息。显示单元506可包括显示面板5061,可以采用液晶显示器(Liquid Crystal Display,LCD)、有机发光二极管(Organic Light-Emitting Diode, OLED)等形式来配置显示面板5061。
用户输入单元507可用于接收输入的数字或字符信息,以及产生与电子设备的用户设置以及功能控制有关的键信号输入。具体地,用户输入单元507包括触控面板5071以及其他输入设备5072。触控面板5071,也称为触摸屏,可收集用户在其上或附近的触摸操作(比如用户使用手指、触笔等任何适合的物体或附件在触控面板5071上或在触控面板5071附近的操作)。触控面板5071可包括触摸检测装置和触摸控制器两个部分。其中,触摸检测装置检测用户的触摸方位,并检测触摸操作带来的信号,将信号传送给触摸控制器;触摸控制器从触摸检测装置上接收触摸信息,并将它转换成触点坐标,再送给处理器510,接收处理器510发来的命令并加以执行。此外,可以采用电阻式、电容式、红外线以及表面声波等多种类型实现触控面板5071。除了触控面板5071,用户输入单元507还可以包括其他输入设备5072。具体地,其他输入设备5072可以包括但不限于物理键盘、功能键(比如音量控制按键、开关按键等)、轨迹球、鼠标、操作杆,在此不再赘述。
进一步的,触控面板5071可覆盖在显示面板5061上,当触控面板5071检测到在其上或附近的触摸操作后,传送给处理器510以确定触摸事件的类型,随后处理器510根据触摸事件的类型在显示面板5061上提供相应的视觉输出。虽然在图5中,触控面板5071与显示面板5061是作为两个独立的部件来实现电子设备的输入和输出功能,但是在某些实施例中,可以将触控面板5071与显示面板5061集成而实现电子设备的输入和输出功能,具体此处不做限定。
接口单元508为外部装置与电子设备500连接的接口。例如,外部装置可以包括有线或无线头戴式耳机端口、外部电源(或电池充电器)端口、有线或无线数据端口、存储卡端口、用于连接具有识别模块的装置的端口、音频输入/输出(I/O)端口、视频I/O端口、耳机端口等等。接口单元508可以用于接收来自外部装置的输入(例如,数据信息、电力等等)并且将接收到的输入传输到电子设备500内的一个或多个元件或者可以用于在电子设备500和外部装置之间传输数据。
存储器509可用于存储软件程序以及各种数据。存储器509可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据手机的使用所创建的数据(比如音频数据、电话本等)等。此外,存储器509可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
处理器510是电子设备的控制中心,利用各种接口和线路连接整个电子设备的各个部分,通过运行或执行存储在存储器509内的软件程序和/或模块,以及调用存储在存储器509内的数据,执行电子设备的各种功能和处理数据,从而对电子设备进行整体监控。处理器510可包括一个或多个处理单元;优选的,处理器510可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器510中。
电子设备500还可以包括给各个部件供电的电源511(比如电池),优选的,电源511可以通过电源管理系统与处理器510逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。
另外,电子设备500包括一些未示出的功能模块,在此不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各个实施例所述的方法。
如图6所示,在本发明提供的又一实施例中,还提供了一种计算机可读存储介质601,该计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述实施例中所述的针对初始模型的知识蒸馏方法。
上面结合附图对本发明的实施例进行了描述,但是本发明并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本发明的启示下,在不脱离本发明宗旨和权利要求所保护的范围情况下,还可做出很多形式,均属于本发明的保护之内。
本领域普通技术人员可以意识到,结合本发明实施例中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统、装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分类部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (17)

1.一种针对初始模型的知识蒸馏方法,其特征在于,所述初始模型包括多层卷积层,所述卷积层配置有对应的全局平均池化层模型,多层所述卷积层与多个所述全局平均池化层模型一一对应,包括:
生成目标图像;
采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;
基于所述目标图像通过前向传播计算,生成针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值,并获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;所述第一全局平均池化层模型预测值为所述全局平均池化层模型在所述第一迭代周期通过前向传播计算生成的全局平均池化层模型预测值;
获取人工标注模型标签值;
计算所述初始模型预测值和所述人工标注模型标签值之间的第一交叉熵损失,并采用所述第一交叉熵损失确定针对所述初始模型的第一最小分类损失值;
计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失,并采用所述第二交叉熵损失确定针对所述全局平均池化层模型的第二最小分类损失值;所述第一最小分类损失值和所述第二最小分类损失值用于参与针对所述初始模型和所述全局平均池化层模型在所述第一迭代周期的反向传播计算;
通过在第一迭代周期反向传播计算更新所述初始模型和所述全局平均池化层模型的参数,通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;
基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;
采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型;
将所述第一目标学生模型确定为第三迭代周期的第二教师模型;
获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值;所述第二全局平均池化层模型预测值为所述全局平均池化层模型在所述第二迭代周期通过前向传播计算生成的全局平均池化层模型预测值;
计算所述第一平滑软标签和所述教师模型预测值之间的第一库尔巴克莱布勒散度,并采用所述第一库尔巴克莱布勒散度确定针对所述第二教师模型的第三最小分类损失值;
计算所述第一平滑软标签和第二全局平均池化层模型预测值之间的第二库尔巴克莱布勒散度,并采用所述第二库尔巴克莱布勒散度确定针对所述全局平均池化层模型的第四最小分类损失值;所述第三最小分类损失值和所述第四最小分类损失值用于参与针对所述初始模型和所述全局平均池化层模型在所述第二迭代周期的反向传播计算;
通过在第二迭代周期反向传播计算更新所述初始模型和所述全局平均池化层模型的参数,通过所述教师模型预测值和所述第二全局平均池化层模型预测值,计算针对所述教师模型和所述全局平均池化层模型的第二分类正确率;
基于所述第二分类正确率消除标签噪声,生成第二平滑软标签;
采用所述教师模型基于所述第二平滑软标签,在所述第三迭代周期生成第二目标学生模型。
2.根据权利要求1所述的方法,其特征在于,还包括:
将最后一次迭代生成的学生模型确定为分类模型;
将所述目标图像输入所述分类模型,输出模型预测概率,并将最大的模型预测概率对应的结果确定为分类结果。
3.根据权利要求2所述的方法,其特征在于,所述基于所述目标图像通过前向传播计算,生成针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值,并获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值的步骤包括:
将所述目标图像作为输入图像,确定针对所述初始模型输出的初始模型预测概率向量,以及与针对所述初始模型预测概率向量对应的第一向量维度;
采用针对所述初始模型预测概率向量和所述第一向量维度,计算针对所述全局平均池化层模型的第一全局平均池化层模型预测概率向量;
通过所述初始模型预测概率向量确定针对所述初始模型的初始模型预测值,并通过所述第一全局平均池化层模型预测概率向量确定针对所述全局平均池化层模型的第一全局平均池化层模型预测值。
4.根据权利要求3所述的方法,其特征在于,通过如下方式计算所述初始模型预测值和所述人工标注模型标签值之间的第一交叉熵损失:
其中,为第一交叉熵损失,为类别数目,t为迭代次数,为样本的类别的人工标记值,表示第次迭代的初始模型对样本的类别的预测值。
5.根据权利要求4所述的方法,其特征在于,通过如下方式计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失:
其中,为第二交叉熵损失,i表示样本下标,表示第次迭代的第层卷积层对应的全局平均池化层模型。
6.根据权利要求5所述的方法,其特征在于,还包括:
确定针对所述初始模型的初始模型权重;
采用所述第一最小分类损失值确定针对所述初始模型的初始模型偏置梯度,基于所述初始模型权重和所述初始模型偏置梯度更新模型参数;
确定针对所述全局平均池化层模型的第一全局平均池化层模型权重;
采用所述第二最小分类损失值确定针对所述全局平均池化层模型的第一全局平均池化层模型偏置梯度,基于所述第一全局平均池化层模型权重和所述第一全局平均池化层模型偏置梯度更新模型参数。
7.根据权利要求6所述的方法,其特征在于,所述获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值的步骤包括:
将所述目标图像作为输入图像,确定针对所述第二教师模型输出的第二教师模型预测概率向量,以及与针对所述第二教师模型预测概率向量对应的第二向量维度;
采用针对所述第二教师模型预测概率向量和所述第二向量维度,计算针对所述全局平均池化层模型的第二全局平均池化层模型预测概率向量;
通过所述第二教师模型预测概率向量确定获取针对所述第二教师模型的教师模型预测值,并通过所述第二全局平均池化层模型预测概率向量确定针对所述全局平均池化层模型的第二全局平均池化层模型预测值。
8.根据权利要求7所述的方法,其特征在于,通过如下方式计算所述第一平滑软标签和所述教师模型预测值之间的第一库尔巴克莱布勒散度:
其中,为第一库尔巴克莱布勒散度,i表示样本下标,t为迭代次数,为类别数目,为第-1次迭代的教师模型输出的平滑软标签,表示第次迭代的初始模型对样本的类别的预测值。
9.根据权利要求8所述的方法,其特征在于,通过如下方式计算所述第一平滑软标签和第二全局平均池化层模型预测值之间的第二库尔巴克莱布勒散度:
其中,为第二库尔巴克莱布勒散度,表示第次迭代的第层卷积层对应的全局平均池化层模型。
10.根据权利要求9所述的方法,其特征在于,还包括:
确定针对所述第二教师模型的第二教师模型权重;
采用所述第三最小分类损失值确定针对所述第二教师模型的第二教师模型偏置梯度,基于所述第二教师模型权重和所述第二教师模型偏置梯度更新模型参数;
确定针对所述全局平均池化层模型的第二全局平均池化层模型权重;
采用所述第四最小分类损失值确定针对所述全局平均池化层模型的第二全局平均池化层模型偏置梯度,基于所述第二全局平均池化层模型权重和所述第二全局平均池化层模型偏置梯度更新模型参数。
11.根据权利要求1所述的方法,其特征在于,所述生成目标图像的步骤包括:
获取初始图像集;
确定目标亮度和目标尺寸;
基于所述目标亮度和所述目标尺寸,对所述初始图像集进行归一化操作,生成目标图像。
12.根据权利要求11所述的方法,其特征在于,所述归一化操作包括均值归一化、方差归一化和阈值归一化。
13.根据权利要求11所述的方法,其特征在于,还包括:
基于缺失值对所述初始图像集执行数据清洗操作;和/或,
基于异常值对所述初始图像集执行数据清洗操作;和/或,
基于噪声数据对所述初始图像集执行数据清洗操作。
14.根据权利要求11所述的方法,其特征在于,还包括:
对所述初始图像集执行数据抽样操作;所述数据抽样操作包括:随机抽样,和/或,分层抽样,和/或,过采样和欠采样。
15.一种针对初始模型的知识蒸馏装置,其特征在于,所述初始模型包括卷积层,所述卷积层配置有对应的全局平均池化层模型,多层所述卷积层与多个所述全局平均池化层模型一一对应,包括:
目标图像生成模块,用于生成目标图像;
第一教师模型确定模块,用于采用所述初始模型在第一迭代周期基于所述目标图像生成初始学生模型,并将所述初始学生模型确定为第二迭代周期的第一教师模型;
预测值获取模块,用于基于所述目标图像通过前向传播计算,生成针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值,并获取针对所述初始模型的初始模型预测值,以及,针对所述全局平均池化层模型的第一全局平均池化层模型预测值;所述第一全局平均池化层模型预测值为所述全局平均池化层模型在所述第一迭代周期通过前向传播计算生成的全局平均池化层模型预测值;获取人工标注模型标签值;计算所述初始模型预测值和所述人工标注模型标签值之间的第一交叉熵损失,并采用所述第一交叉熵损失确定针对所述初始模型的第一最小分类损失值;计算所述第一全局平均池化层模型预测值和所述人工标注模型标签值之间的第二交叉熵损失,并采用所述第二交叉熵损失确定针对所述全局平均池化层模型的第二最小分类损失值;所述第一最小分类损失值和所述第二最小分类损失值用于参与针对所述初始模型和所述全局平均池化层模型在所述第一迭代周期的反向传播计算;
第一分类正确率计算模块,用于通过在第一迭代周期反向传播计算更新所述初始模型和所述全局平均池化层模型的参数,通过所述初始模型预测值和所述第一全局平均池化层模型预测值,计算针对所述初始模型和所述全局平均池化层模型的第一分类正确率;
第一平滑软标签生成模块,用于基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;
第一目标学生模型生成模块,用于采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型;基于所述第一分类正确率消除标签噪声,生成第一平滑软标签;采用第一教师模型基于所述第一平滑软标签,在所述第二迭代周期生成第一目标学生模型;将所述第一目标学生模型确定为第三迭代周期的第二教师模型;获取针对所述第二教师模型的教师模型预测值和针对所述全局平均池化层模型的第二全局平均池化层模型预测值;所述第二全局平均池化层模型预测值为所述全局平均池化层模型在所述第二迭代周期通过前向传播计算生成的全局平均池化层模型预测值;计算所述第一平滑软标签和所述教师模型预测值之间的第一库尔巴克莱布勒散度,并采用所述第一库尔巴克莱布勒散度确定针对所述第二教师模型的第三最小分类损失值;计算所述第一平滑软标签和第二全局平均池化层模型预测值之间的第二库尔巴克莱布勒散度,并采用所述第二库尔巴克莱布勒散度确定针对所述全局平均池化层模型的第四最小分类损失值;所述第三最小分类损失值和所述第四最小分类损失值用于参与针对所述初始模型和所述全局平均池化层模型在所述第二迭代周期的反向传播计算;通过在第二迭代周期反向传播计算更新所述初始模型和所述全局平均池化层模型的参数,通过所述教师模型预测值和所述第二全局平均池化层模型预测值,计算针对所述教师模型和所述全局平均池化层模型的第二分类正确率;基于所述第二分类正确率消除标签噪声,生成第二平滑软标签;采用所述教师模型基于所述第二平滑软标签,在所述第三迭代周期生成第二目标学生模型。
16.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,所述处理器、所述通信接口以及所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行存储器上所存放的程序时,实现如权利要求1-14任一项所述的方法。
17.一种计算机可读存储介质,其上存储有指令,当由一个或多个处理器执行时,使得所述处理器执行如权利要求1-14任一项所述的方法。
CN202311481966.7A 2023-11-08 2023-11-08 一种针对初始模型的知识蒸馏方法和装置 Active CN117237742B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311481966.7A CN117237742B (zh) 2023-11-08 2023-11-08 一种针对初始模型的知识蒸馏方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311481966.7A CN117237742B (zh) 2023-11-08 2023-11-08 一种针对初始模型的知识蒸馏方法和装置

Publications (2)

Publication Number Publication Date
CN117237742A CN117237742A (zh) 2023-12-15
CN117237742B true CN117237742B (zh) 2024-02-20

Family

ID=89086354

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311481966.7A Active CN117237742B (zh) 2023-11-08 2023-11-08 一种针对初始模型的知识蒸馏方法和装置

Country Status (1)

Country Link
CN (1) CN117237742B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118101339B (zh) * 2024-04-23 2024-07-23 山东科技大学 一种应对物联网隐私保护的联邦知识蒸馏方法

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20220096099A (ko) * 2020-12-30 2022-07-07 성균관대학교산학협력단 지식 증류에서 총 cam 정보를 이용한 교사 지원 어텐션 전달의 학습 방법 및 장치
CN115994611A (zh) * 2022-10-25 2023-04-21 京东城市(北京)数字科技有限公司 类别预测模型的训练方法、预测方法、设备和存储介质
CN116168439A (zh) * 2023-03-03 2023-05-26 中南大学 一种轻量级唇语识别方法及相关设备
CN116543250A (zh) * 2023-03-29 2023-08-04 西安电子科技大学 一种基于类注意力传输的模型压缩方法

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20220096099A (ko) * 2020-12-30 2022-07-07 성균관대학교산학협력단 지식 증류에서 총 cam 정보를 이용한 교사 지원 어텐션 전달의 학습 방법 및 장치
CN115994611A (zh) * 2022-10-25 2023-04-21 京东城市(北京)数字科技有限公司 类别预测模型的训练方法、预测方法、设备和存储介质
CN116168439A (zh) * 2023-03-03 2023-05-26 中南大学 一种轻量级唇语识别方法及相关设备
CN116543250A (zh) * 2023-03-29 2023-08-04 西安电子科技大学 一种基于类注意力传输的模型压缩方法

Also Published As

Publication number Publication date
CN117237742A (zh) 2023-12-15

Similar Documents

Publication Publication Date Title
CN109919251B (zh) 一种基于图像的目标检测方法、模型训练的方法及装置
CN110009052B (zh) 一种图像识别的方法、图像识别模型训练的方法及装置
CN109086709B (zh) 特征提取模型训练方法、装置及存储介质
WO2020177582A1 (zh) 视频合成的方法、模型训练的方法、设备及存储介质
CN108304758B (zh) 人脸特征点跟踪方法及装置
CN110738211B (zh) 一种对象检测的方法、相关装置以及设备
CN112990390B (zh) 一种图像识别模型的训练方法、图像识别的方法及装置
CN111797288B (zh) 数据筛选方法、装置、存储介质及电子设备
CN113284142B (zh) 图像检测方法、装置、计算机可读存储介质及计算机设备
CN114418069B (zh) 一种编码器的训练方法、装置及存储介质
CN117237742B (zh) 一种针对初始模型的知识蒸馏方法和装置
CN110516113B (zh) 一种视频分类的方法、视频分类模型训练的方法及装置
WO2024148870A1 (zh) 一种调频方法、装置、电子设备及可读存储介质
CN112184548A (zh) 图像超分辨率方法、装置、设备及存储介质
CN110162956A (zh) 确定关联账户的方法和装置
CN111738100B (zh) 一种基于口型的语音识别方法及终端设备
CN114612830A (zh) 一种花屏图像的识别方法、装置、设备以及存储介质
CN112488157B (zh) 一种对话状态追踪方法、装置、电子设备及存储介质
CN116486463B (zh) 图像处理方法、相关装置及存储介质
CN117523632A (zh) 一种面瘫等级分析方法和相关装置
WO2023137923A1 (zh) 基于姿态指导的行人重识别方法、装置、设备及存储介质
CN117541770A (zh) 数据增强方法、装置及电子设备
CN110119383A (zh) 一种文件管理方法及终端设备
CN112150174B (zh) 一种广告配图方法、装置及电子设备
CN114943639B (zh) 图像获取方法、相关装置及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant