CN114359563B - 模型训练方法、装置、计算机设备和存储介质 - Google Patents
模型训练方法、装置、计算机设备和存储介质 Download PDFInfo
- Publication number
- CN114359563B CN114359563B CN202210274888.2A CN202210274888A CN114359563B CN 114359563 B CN114359563 B CN 114359563B CN 202210274888 A CN202210274888 A CN 202210274888A CN 114359563 B CN114359563 B CN 114359563B
- Authority
- CN
- China
- Prior art keywords
- pixel
- sample image
- category
- trained
- image
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Landscapes
- Image Analysis (AREA)
Abstract
本申请涉及一种模型训练方法、装置、计算机设备、存储介质和计算机程序产品。所述方法包括:获取已训练的教师模型对样本图像中各像素的类别预测结果;根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量;在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数;根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型;训练完成的学生模型用于对输入的图像进行语义分割。采用本方法能够提升学生模型整体的预测准确性。
Description
技术领域
本申请涉及计算机技术领域,特别是涉及一种模型训练方法、装置、计算机设备、存储介质和计算机程序产品。
背景技术
知识蒸馏技术是在模型训练过程中,使用一个规模较大的模型作为老师模型进行训练,提取出图像样本中的特征信息,然后将特征信息传递给规模较小的学生模型,使得规模较小的学生模型不仅速度较快,还能借助特征信息提升模型性能。
然而,传统的知识蒸馏技术是直接将蒸馏损失函数应用在所有图像样本上,并没有考虑图像样本之间的差异性,差异性包括图像样本的类别数量和图形样本包含的信息量,使得在模型训练过程中模型会更倾向于信息量较少的多数类样本,而忽视信息量较大的少数类样本,造成学生模型在信息量较大的少数类样本上的预测准确性较低。
发明内容
基于此,有必要针对上述技术问题,提供一种能够提升学生模型预测准确率的模型训练方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
第一方面,本申请提供了一种模型训练方法。所述方法包括:
获取已训练的教师模型对样本图像中各像素的类别预测结果;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;
在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
在其中一个实施例中,各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数包括:
在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量的权重;
根据各预设图像类别下所述样本图像中各像素的信息量和所述样本图像中各像素的信息量的权重,得到所述待训练的学生模型的目标损失函数。
在其中一个实施例中,在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量的权重,包括:
根据所述样本图像中各像素的类别预测结果,从所述各预设图像类别中确定出所述样本图像中各像素所属的图像类别;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息散度;所述信息散度表示所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果之间的距离;
在所述各预设图像类别下,根据所述样本图像中各像素的信息散度,依次确定所述样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量的权重,得到所述样本图像中各像素的信息量的权重。
在其中一个实施例中,根据所述样本图像中各像素的信息量和所述样本图像中各像素的信息量的权重,得到所述待训练的学生模型的目标损失函数,包括:
在所述各预设图像类别下,分别根据样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量和所述与所述预设图像类别相同的像素的信息量的权重,确定所述样本图像在所述各预设图像类别下的总信息量;
根据所述样本图像在所述各预设图像类别下的总信息量之和的平均值,得到所述待训练的学生模型的目标损失函数。
在其中一个实施例中,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量,包括:
根据所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素在所述各预设图像类别下的类别信息量;
根据所述各像素在所述各预设图像类别下的类别信息量,得到所述各像素的信息量。
在其中一个实施例中,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量,还包括:
分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果;
分别根据所述各像素的目标预测结果,得到所述样本图像中各像素的信息量。
在其中一个实施例中,分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果,包括:
分别从所述各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为所述各像素的目标预测结果。
在其中一个实施例中,分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果,包括:
分别从所述各像素的类别预测结果中,筛选出类别预测概率最大和类别预测概率第二大的类别预测结果,作为所述各像素的目标预测结果。
在其中一个实施例中,根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型,包括:
获取所述待训练的学生模型的初始损失函数;
根据所述初始损失函数和所述目标损失函数,得到总损失函数;
根据所述总损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型。
在其中一个实施例中,根据所述初始损失函数和所述目标损失函数,得到总损失函数,包括:
将所述初始损失函数与所述目标损失函数进行相加,得到所述总损失函数。
第二方面,本申请还提供了一种模型训练装置。所述装置包括:
像素预测模块,用于获取已训练的教师模型对样本图像中各像素的类别预测结果;
信息提取模块,用于根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;
函数获取模块,用于在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
模型获取模块,用于根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
第三方面,本申请还提供了一种计算机设备。所述计算机设备包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
获取已训练的教师模型对样本图像中各像素的类别预测结果;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;
在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
第四方面,本申请还提供了一种计算机可读存储介质。所述计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
获取已训练的教师模型对样本图像中各像素的类别预测结果;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;
在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
第五方面,本申请还提供了一种计算机程序产品。所述计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现以下步骤:
获取已训练的教师模型对样本图像中各像素的类别预测结果;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;
在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
上述模型训练方法、装置、计算机设备、存储介质和计算机程序产品,通过根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量;在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数;进而根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型;训练完成的学生模型用于对输入的图像进行语义分割。采用本方法,通过样本图像中各像素的信息量来得到待训练的学生模型的目标损失函数,目标损失函数能够使学生模型在训练过程中更关注样本图像中占有较高信息量的少数类,提高了学生模型在信息量较大的像素上的预测准确性,从而提升了学生模型整体的预测准确性。
附图说明
图1为一个实施例中模型训练方法的流程示意图;
图2为一个实施例中获取待训练的学生模型的目标损失函数步骤的流程示意图;
图3为另一个实施例中模型训练方法的流程示意图;
图4为一个实施例中模型训练装置的结构框图;
图5为一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
在一个实施例中,如图1所示,提供了一种模型训练方法,本实施例以该方法应用于服务器进行举例说明,可以理解的是,该方法也可以应用于服务器,还可以应用于包括终端和服务器的系统,并通过终端和服务器的交互实现。其中,样本图像可以是预先存储在服务器中的,也可以是终端设备发送到服务器中的。本实施例中,该方法包括以下步骤:
步骤S101,获取已训练的教师模型对样本图像中各像素的类别预测结果。
其中,教师模型和学生模型是知识蒸馏场景下对不同分支模型的描述,教师模型的参数的数量多于学生模型的参数的数量。
实际应用中,知识蒸馏作为一种重要的模型压缩手段,可以将教师模型在知识机器学习过程挖掘得到的知识,迁移到学生模型中,以使得学生模型用更少的空间复杂度和训练时间得到与教师模型类似的训练效果,并且学生模型的拟合能力能够逼近甚至超过教师模型。
其中,像素的类别预测结果包括像素在各预设图像类别下的预测结果。
举例说明,假设预设图像类别包括类别1、类别2和类别3,则像素的类别预测结果会包括预测像素分别在类别1、类别2和类别3下的概率。
具体地,服务器获取已经预先训练完成的教师模型和样本图像,将样本图像输入到已训练教师模型中,通过已训练的教师模型对样本图像中各像素的类别进行预测,得到已训练的教师模型输出的样本图像中各像素的类别预测结果。
步骤S102,根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量。
其中,信息量是指样本图像中的像素包含的信息。信息论中通过事件的概率来描述不确定性,本实施例中像素包含的信息是通过类别预测概率来描述不确定性。
其中,样本图像中易分割、预测简单的像素,其包含的信息量较小,样本图像中难分割、预测困难的像素,其包含的信息量较大。
需要说明的是,由于待训练的学生模型还缺乏训练,使得待训练的学生模型的预测结果并不可靠,因此,使用已训练的教师模型来获取样本图像中各像素的类别预测结果。
具体地,服务器根据已训练的教师模型对样本图像中各像素的类别预测结果,通过信息论中的信息量评价指标来计算样本图像中每个像素的不确定性,即像素的信息量;由此,服务器在得到信息量之后,将信息量作为处理依据执行后续的目标损失函数获取步骤。
步骤S103,在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数。
其中,目标损失函数用于调整模型训练过程中样本图像的类别权重。
具体地,服务器根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果之间的距离,可以构建初步的损失函数,再根据样本图像中各像素的信息量,确定样本图像中各像素在各预设图像类别下的权重,进而根据样本图像中各像素在各预设类别下的权重和各预设类别下的初步的损失函数,得到待训练的学生模型的目标损失函数。
实际应用中,初步的损失函数的构建可以通过信息散度(Kullback-Leiblerdivergence,KLD)、余弦相似度、交叉熵等技术实现。
步骤S104,根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型;训练完成的学生模型用于对输入的图像进行语义分割。
其中,教师模型和学生模型可以是用于语义分割的网络模型。
其中,语义分割是指对样本图像进行像素级别的预测或分类,为达到较优的预测性能,往往需要较大的参数量来构建出准确率较高的模型,即教师模型,同时,这也会使得教师模型的预测速度较慢,无法适用于移动场景中的速度和时间需求;而训练完成的学生模型,由于自身参数数量较小的网络结构,使得学生模型在保持较快的处理速度的同时,还通过教师模型迁移过来的知识,提升了自身的预测性能,从而比训练完成的学生模型教师模型拥有更强的处理速度和处理准确率。
具体地,服务器将目标损失函数作为待训练的学生模型的训练过程的损失函数,对待训练的学生模型进行迭代训练,当学生模型满足预设的训练条件时,得到训练完成的学生模型。由此,服务器在得到训练完成的学生模型之后,将训练完成的学生模型作为处理依据执行后续的语义分割步骤。通过训练完成的学生模型对输入的图像进行语义分割,得到语义分割结果,将语义分割结果存储在服务器中或者发送给终端设备。
上述模型训练方法中,通过根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量;在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数;进而根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型;训练完成的学生模型用于对输入的图像进行语义分割。采用本方法,通过样本图像中各像素的信息量来得到待训练的学生模型的目标损失函数,能够使学生模型在训练过程中更关注样本图像中占有较高信息量的少数类,提高了学生模型在信息量较大的像素上的预测准确性,从而提升了学生模型整体的预测准确性。
在一个实施例中,如图2所示,在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数,具体包括如下内容:
步骤S201,在各预设图像类别下,根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,确定样本图像中各像素的信息量的权重;
步骤S202,根据各预设图像类别下样本图像中各像素的信息量和样本图像中各像素的信息量的权重,得到待训练的学生模型的目标损失函数。
其中,像素的信息量的权重用于在学生模型的训练过程中,赋予样本图像中信息量更高的像素更多的关注。
举例说明,语义分割场景中会面临像素的信息量不均衡的问题,例如,样本图像中包含有100个像素点,该样本图像的预设图像类别包括A、B、C三类,其中,C类中90%以上的像素点属于较容易分割区域,10%像素点的区域较难分割区域,由于较容易分割区域在C类中占比较大,使得较容易分割区域主导了较难分割区域的梯度,导致语义分割模型无法关注到各预设图像类别中信息量较大的较难分割区域;通过调整像素的信息量的权重,能够赋予信息量更高的像素更多的关注,即赋予样本图像中较难分割区域更多的关注。
进一步地,语义分割场景中还会面临像素的类别不均的问题,例如,样本图像中包含有100个像素点,该样本图像的预设图像类别包括A、B、C三类,其中,属于A类的像素点有68个,属于B类的像素点有15个,属于C类的像素点有17个;在该情况下,会导致模型在训练过程更多的关注在A类像素点的训练上,而忽略了B类和C类像素点;通过调整像素的权重,能够赋予少数类更高的关注,即赋予B类和C类像素点更高的关注。
具体地,服务器分别在各预设图像类别下,根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,确定样本图像中各像素的信息量在对应图像类别下的权重,进而根据各预设图像类别下样本图像中各像素的信息量,以及样本图像中各像素的信息量在对应图像类别下的权重,得到待训练的学生模型的目标损失函数。
在实际的语义分割场景中,语义分割是一个密集分类任务,物体的不同部位的分类难易程度亦有所不同,样本图像中大部分都是相对较简单、易分割的部分,其包含的信息量亦较少,所以蒸馏损失函数基本上是由简单的部分占主导作用,使得梯度较小,而样本图像中真正富含信息的占小部分的像素被压制。由于占比更多的像素很有可能是背景图像,占比更少的像素是真正需要精确识别的人物、物品图像,因此本公开不仅考虑了教师模型和学生模型在预测结果上的差异性,还通过信息量来调整各预设图像类别下的像素的权重,使得学生模型在信息量更高的像素上具有更高的预测准确率。
本实施例中,在各预设图像类别下,根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,确定样本图像中各像素的信息量的权重,以使样本图像中各像素的信息量具有不同的权重;通过根据各预设图像类别下样本图像中各像素的信息量和样本图像中各像素的信息量的权重,得到待训练的学生模型的目标损失函数,进而实现了在目标损失函数中考虑样本图像中每个像素之间的差异性,从而提高了学生模型的预测准确性。
在一个实施例中,上述步骤S201,在各预设图像类别下,根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,确定样本图像中各像素的信息量的权重,具体包括如下内容:根据样本图像中各像素的类别预测结果,从各预设图像类别中确定出样本图像中各像素所属的图像类别;根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息散度;信息散度表示已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果之间的距离;在各预设图像类别下,根据样本图像中各像素的信息散度,依次确定样本图像中所属的图像类别与预设图像类别相同的像素的信息量的权重,得到样本图像中各像素的信息量的权重。
具体地,服务器检测样本图像中各像素的类别预测结果,得到各像素的类别预测结果中数值最大的预测结果,从各预设图像类别中,确定出数值最大的预测结果所属的图像类别,作为样本图像中像素所属的图像类别;例如,样本图像中一像素的类别预测结果为(95%,2%,3%),即该像素属于图像类别1的概率为95%,该像素属于图像类别2的概率为2%,该像素属于图像类别3的概率为3%,因为95%>3%>2%,所以该像素属于图像类别1。根据信息散度可以通过已训练的教师模型对样本图像中各像素的类别预测结果的信息熵和待训练的学生模型对样本图像中各像素的类别预测结果的信息熵之间的差值,得到样本图像中各像素的信息散度;分别在各预设图像类别下,依次判断样本图像中各像素所属的图像类别与预设图像类别是否相同,若样本图像中像素所属的图像类别与预设图像类别相同,则根据该像素的信息散度,确定该像素的信息量的权重;若样本图像中像素所属的图像类别与预设图像类别不相同,则继续判断样本图像中下一个像素所属的图像类别与预设图像类别是否相同,直到得到样本图像中各像素的信息量的权重。
举例说明,预设图像类别包括图像类别1、图像类别2和图像类别3,则先确定样本图像中属于图像类别1的像素的信息量的权重,再确定样本图像中属于图像类别2的像素的信息量的权重,最后确定样本图像中属于图像类别3的像素的信息量的权重,最终得到样本图像中各像素的信息量的权重。
本实施例中,服务器根据样本图像中各像素的类别预测结果,从各预设图像类别中确定出样本图像中各像素所属的图像类别;根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息散度;进而在各预设图像类别下,根据样本图像中各像素的信息散度,依次确定样本图像中所属的图像类别与预设图像类别相同的像素的信息量的权重,得到样本图像中各像素的信息量的权重,从而实现了获取样本图像中各像素的信息量的权重,考虑了样本图像中各像素的信息量。
在一个实施例中,上述步骤S202,根据样本图像中各像素的信息量和样本图像中各像素的信息量的权重,得到待训练的学生模型的目标损失函数,具体包括如下内容:在各预设图像类别下,分别根据样本图像中所属的图像类别与预设图像类别相同的像素的信息量和与预设图像类别相同的像素的信息量的权重,确定样本图像在各预设图像类别下的总信息量;根据样本图像在各预设图像类别下的总信息量之和的平均值,得到待训练的学生模型的目标损失函数。
其中,总信息量是指样本图像中所有像素的信息量之和。
具体地,服务器分别在各预设图像类别下,依次判断样本图像中各像素所属的图像类别与预设图像类别是否相同,若样本图像中像素所属的图像类别与预设图像类别不相同,则将该像素的信息量和像素的信息量的权重设置为0;若样本图像中像素所属的图像类别与预设图像类别相同,则根据样本图像中像素的信息量和像素的信息量的权重,确定样本图像在各预设图像类别下的总信息量。将样本图像分别在各预设图像类别下的总信息量相加,得到样本图像在所有预设图像类别下的总信息量之和,根据预设图像类别的数量,获取样本图像在各预设图像类别下的总信息量之和的平均值,根据样本图像在各预设图像类别下的总信息量之和的平均值,得到待训练的学生模型的目标损失函数。
其中,K表示为预设图像类别的数量,H表示为样本图像的高,W表示为样本图像的
宽,表示为样本图像中像素点x的信息量,表示为像素x的信息散度,表示为待训练的学生模型对样本图像中像素x的类别预测结果,表示为已训练的教师
模型对样本图像中像素x的类别预测结果,1(y x =k)表示像素x所属的图像类别是否等于预
设图像类别k,当像素x所属的图像类别等于预设图像类别k时,则1(y x =k)为1,否则1(y x =k)
为0。
本实施例中,服务器在各预设图像类别下,分别根据样本图像中所属的图像类别与预设图像类别相同的像素的信息量和与预设图像类别相同的像素的信息量的权重,确定样本图像在各预设图像类别下的总信息量,根据总信息量确定待训练的学生模型的目标损失函数。采用本方法,分别根据各预设图像类别、各像素的信息量和信息量的权重来确定目标损失函数,不仅实现了通过像素的信息量来调整像素的权重,还实现了通过每个预设图像类别的总信息量来调整各预设图像类别之间的权重,解决了知识蒸馏在语义分割场景中的样本图像的像素信息量不均和像素类别不均的问题,从而提升了学生模型整体的预测准确性。
在一个实施例中,根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量,具体包括如下内容:根据样本图像中各像素的类别预测结果,得到样本图像中各像素在各预设图像类别下的类别信息量;根据各像素在各预设图像类别下的类别信息量,得到各像素的信息量。
其中,信息量与各预设图像类别下的类别预测结果分布相关,预设图像类别下的类别预测结果的分布越极端,即类别预测结果中最大预测概率的数值越大,则像素的信息量越小;反之,类别预测结果的分布越均衡,即类别预测结果中最大预测概率的数值越小,则像素的信息量越大。
具体地,服务器在各预设图像类别下,对样本图像中各像素的类别预测结果的对数进行求负,得到样本图像中各像素在各预设图像类别下的类别信息量,将各像素在各预设图像类别下的类别信息量进行相加,得到各像素的信息量。
也就是说,先分别确定像素在各预设图像类别下的类别信息量,再将所有图像类别的类别信息量之和,作为像素的信息量。
本实施例中,服务器根据样本图像中各像素的类别预测结果,得到样本图像中各像素在各预设图像类别下的类别信息量,进而根据各像素在各预设图像类别下的类别信息量,得到各像素的信息量,从而实现基于样本图像中各像素的类别预测结果准确的确定各像素的信息量。
在一个实施例中,根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量,还包括:分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果;分别根据各像素的目标预测结果,得到样本图像中各像素的信息量。
具体地,服务器根据预设的信息量度量方式,确定第一预设条件,分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果;然后分别根据各像素的目标预测结果和预设的信息量度量方式,确定样本图像中各像素的信息量。
本实施例中,服务器分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果,根据目标预测结果,确定像素的信息量,从而实现基于样本图像中各像素的类别预测结果准确的确定各像素的信息量。
在一个实施例中,分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果,包括:分别从各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为各像素的目标预测结果。
其中,类别预测概率是指像素属于各预设图像类别的概率。
举例说明,样本图像中某一像素的类别预测结果为(95%,2%,3%),其中,类别预测概率是指95%、2%、3%,而95%就是类别预测概率最大的类别预测结果。
具体地,服务器还可以根据最低置信度(Least Confident)来确定像素的信息量,将设置第一预设条件设置为类别预测概率最大,从各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为各像素的目标预测结果,即第一预设条件为类别预测概率最大的类别预测结果;此时,目标预测结果的概率越大,像素的信息量越小,反之,目标预测结果的概率越小,像素的信息量越大。
在本实施例中,服务器将各像素的类别预测结果中类别预测概率最大的,作为各像素的目标预测结果,进而根据目标预测结果,得到各像素的信息量,从而实现基于样本图像中各像素的类别预测结果准确的确定各像素的信息量。
在一个实施例中,分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果,包括:分别从各像素的类别预测结果中,筛选出类别预测概率最大和类别预测概率第二大的类别预测结果,作为各像素的目标预测结果。
具体地,服务器还可以根据边界采样(Margin Sampling)来确定像素的信息量,从各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果和类别预测概率第二大的类别预测结果,作为各像素的目标预测结果,即第一预设条件为类别预测概率最大的类别预测结果和类别预测概率第二大的类别预测结果;根据概率数值最大的类别预测结果和概率数值第二大的类别预测结果之间的差值,确定像素的信息量,差值越大,像素的信息量越小,反之,差值越小,像素的信息量越大。
举例说明,样本图像中某一像素的类别预测结果为(95%,2%,3%),其中,类别预测概率是指95%、2%、3%,而95%和3%分别是指类别预测概率最大的类别预测结果和类别预测概率第二大的类别预测结果。
在本实施例中,服务器根据像素的类别预测结果中类别预测概率最大的类别预测结果和类别预测概率第二大的类别预测结果之间的差值,确定像素的信息量,从而实现基于样本图像中各像素的类别预测结果准确的确定各像素的信息量。
在一个实施例中,根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型,包括:获取待训练的学生模型的初始损失函数;根据初始损失函数和目标损失函数,得到总损失函数;根据总损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型。
其中,初始损失函数可以是待训练的学生模型的交叉熵、信息散度、余弦相似度等损失函数。
具体地,获取待训练的学生模型的初始损失函数和损失权重,将损失权重作为目标损失函数的目标损失权重,根据初始损失函数、目标损失函数和目标损失函数的目标损失权重,得到总损失函数;根据总损失函数,对待训练的学生模型进行迭代训练。
进一步地,当检测到待训练的学生模型的预测结果满足第二预设条件时,将预测结果满足第二预设条件的学生模型作为训练完成的学生模型;或者,当待训练的学生模型经过迭代训练的次数达到预设训练次数时,将达到预设训练次数时的学生模型作为训练完成的学生模型。其中,当目标损失函数为1时,训练完成的学生模型通常具有较好的预测效果,当然目标损失函数也可以根据实际应用场景取其他数值。
本实施例中,服务器通过获取待训练的学生模型的初始损失函数,根据初始损失函数和目标损失函数,得到总损失函数,进而根据总损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型,从而基于训练完成的学生模型可以对输入的样本图像进行语义分割处理。
在一个实施例中,根据初始损失函数和目标损失函数,得到总损失函数,具体包括如下内容:将初始损失函数与目标损失函数进行相加,得到总损失函数。
具体地,将目标损失函数和目标损失函数的目标损失权重进行相乘,得到相乘后损失函数,将相乘后损失函数与初始损失函数进行相加,得到总损失函数。
举例说明,总损失函数L可以通过如下公式得到:
在实际应用中,教师模型可以是101层的深度残差网络(Deep residual network,ResNet),即ResNet-101,学生模型可以是18层的深度残差网络,即ResNet-18。基于ResNet-101教师模型和ResNet-18学生模型对本公开提出的模型训练方法进行测试,选择语义分割场景中常使用的城市景观(cityscapes)数据集和PASCAL Context数据集;在教师模型、学生模型、教师模型和学生模型的预设超参数都相同的情况下,分别采用本公开中提出的模型训练方法与语义分割场景中的传统技术KD(Distilling the knowledge in a neuralnetwork)、SDK(Structured knowledge distillation for dense prediction)、IFVD(Intra-class feature variation distillation for semantic segmentation)、CSCACE(Knowledge distillation for semantic segmentation using channel and spatialcorrelations and adaptive crossentropy)和KA(Knowledge adaptation forefficient semantic segmentation),对ResNet-18学生模型进行训练,得到本公开提出的模型训练方法得到的训练完成的学生模型的预测准确率与其他方法得到的训练完成的学生模型的预测准确率,预测准确率的对比结果如表1所示。
表1模型的预测准确率(%)
由表1可知,通过本公开中提出的模型训练方法得到的训练完成的学生模型,在城市景观(cityscapes)数据集和PASCAL Context数据集上,比传统技术KD、SDK、IFVD、CSCACE和KA得到的训练完成的学生模型具有更高的准确率。因此,本公开中提出的模型训练方法,在不影响学生模型的预测速度、不增加学生模型的参数、不改变学生模型的预设超参数的同时,还实现了学生模型的预测准确率的有效提升。
本实施例中,服务器将初始损失函数和目标损失函数进行相加,得到总损失函数,进而根据总损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型,从而基于训练完成的学生模型可以对输入的样本图像进行语义分割处理。
在一个实施例中,如图3所示,提供了另一种模型训练方法,以该方法应用于服务器为例进行说明,包括以下步骤:
步骤S301,获取已训练的教师模型对样本图像中各像素的类别预测结果。
步骤S302,根据样本图像中各像素的类别预测结果,得到样本图像中各像素在各预设图像类别下的类别信息量;根据各像素在各预设图像类别下的类别信息量,得到各像素的信息量。
步骤S303,分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果;分别根据各像素的目标预测结果,得到样本图像中各像素的信息量;
其中,步骤S303还包括步骤S303-1,分别从各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为各像素的目标预测结果;步骤S303还包括步骤S303-2,分别从各像素的类别预测结果中,筛选出类别预测概率最大和类别预测概率第二大的类别预测结果,作为各像素的目标预测结果。
步骤S304,根据样本图像中各像素的类别预测结果,从各预设图像类别中确定出样本图像中各像素所属的图像类别。
步骤S305,根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息散度。
步骤S306,在各预设图像类别下,根据样本图像中各像素的信息散度,依次确定样本图像中所属的图像类别与预设图像类别相同的像素的信息量的权重,得到样本图像中各像素的信息量的权重。
步骤S307,在各预设图像类别下,分别根据样本图像中所属的图像类别与预设图像类别相同的像素的信息量和与预设图像类别相同的像素的信息量的权重,确定样本图像在各预设图像类别下的总信息量。
步骤S308,根据样本图像在各预设图像类别下的总信息量之和的平均值,得到待训练的学生模型的目标损失函数。
步骤S309,获取待训练的学生模型的初始损失函数;根据初始损失函数和目标损失函数,得到总损失函数;
其中,步骤S309还包括步骤S309-1,将初始损失函数与目标损失函数进行相加,得到总损失函数。
步骤S310,根据总损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型。
上述模型训练方法,能够实现以下有益效果:分别根据各预设图像类别、各像素的信息量和信息量的权重来确定目标损失函数,不仅实现了通过像素的信息量来调整像素的权重,还实现了通过每个预设图像类别的总信息量来调整各预设图像类别之间的权重,解决了知识蒸馏在语义分割场景中的样本图像的像素信息量不均和像素类别不均的问题,从而提升了学生模型整体的预测准确性。
应该理解的是,虽然如上所述的各实施例所涉及的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,如上所述的各实施例所涉及的流程图中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
基于同样的发明构思,本申请实施例还提供了一种用于实现上述所涉及的模型训练方法的模型训练装置。该装置所提供的解决问题的实现方案与上述方法中所记载的实现方案相似,故下面所提供的一个或多个模型训练装置实施例中的具体限定可以参见上文中对于模型训练方法的限定,在此不再赘述。
在一个实施例中,如图4所示,提供了一种模型训练装置400,包括:像素预测模块401、信息提取模块402、函数获取模块403和模型训练模块404,其中:
像素预测模块401,用于获取已训练的教师模型对样本图像中各像素的类别预测结果。
信息提取模块402,用于根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量。
函数获取模块403,用于在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数。
模型获取模块404,用于根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型;训练完成的学生模型用于对输入的图像进行语义分割。
在一个实施例中,函数获取模块403,还用于在各预设图像类别下,根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,确定样本图像中各像素的信息量的权重;根据各预设图像类别下样本图像中各像素的信息量和样本图像中各像素的信息量的权重,得到待训练的学生模型的目标损失函数。
在一个实施例中,模型训练装置400还包括权重确定模块,用于根据样本图像中各像素的类别预测结果,从各预设图像类别中确定出样本图像中各像素所属的图像类别;根据已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息散度;信息散度表示已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果之间的距离;在各预设图像类别下,根据样本图像中各像素的信息散度,依次确定样本图像中所属的图像类别与预设图像类别相同的像素的信息量的权重,得到样本图像中各像素的信息量的权重。
在一个实施例中,模型训练装置400还包括函数确定模块,用于在各预设图像类别下,分别根据样本图像中所属的图像类别与预设图像类别相同的像素的信息量和与预设图像类别相同的像素的信息量的权重,确定样本图像在各预设图像类别下的总信息量;根据样本图像在各预设图像类别下的总信息量之和的平均值,得到待训练的学生模型的目标损失函数。
在一个实施例中,信息提取模块402,还用于根据样本图像中各像素的类别预测结果,得到样本图像中各像素在各预设图像类别下的类别信息量;根据各像素在各预设图像类别下的类别信息量,得到各像素的信息量。
在一个实施例中,信息提取模块402,还用于分别从各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为各像素的目标预测结果;分别根据各像素的目标预测结果,得到样本图像中各像素的信息量。
在一个实施例中,模型训练装置400还包括第一条件确定模块,用于分别从各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为各像素的目标预测结果。
在一个实施例中,模型训练装置400还包括第二条件确定模块,用于分别从各像素的类别预测结果中,筛选出类别预测概率最大和类别预测概率第二大的类别预测结果,作为各像素的目标预测结果。
在一个实施例中,模型获取模块404,还用于获取待训练的学生模型的初始损失函数;根据初始损失函数和目标损失函数,得到总损失函数;根据总损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型。
在一个实施例中,模型训练装置400还包括函数相加模块,用于将初始损失函数与目标损失函数进行相加,得到总损失函数。
上述模型训练装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图5所示。该计算机设备包括通过系统总线连接的处理器、存储器和网络接口。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质和内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储样本图像、样本图像中各像素的预测结果和各像素的信息量等数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种模型训练方法。
本领域技术人员可以理解,图5中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,还提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现上述各方法实施例中的步骤。
需要说明的是,本申请所涉及的用户信息(包括但不限于用户设备信息、用户个人信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户授权或者经过各方充分授权的信息和数据。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-OnlyMemory,ROM)、磁带、软盘、闪存、光存储器、高密度嵌入式非易失性存储器、阻变存储器(ReRAM)、磁变存储器(Magnetoresistive Random Access Memory,MRAM)、铁电存储器(Ferroelectric Random Access Memory,FRAM)、相变存储器(Phase Change Memory,PCM)、石墨烯存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器等。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic RandomAccess Memory,DRAM)等。本申请所提供的各实施例中所涉及的数据库可包括关系型数据库和非关系型数据库中至少一种。非关系型数据库可包括基于区块链的分布式数据库等,不限于此。本申请所提供的各实施例中所涉及的处理器可为通用处理器、中央处理器、图形处理器、数字信号处理器、可编程逻辑器、基于量子计算的数据处理逻辑器等,不限于此。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对本申请专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请的保护范围应以所附权利要求为准。
Claims (12)
1.一种模型训练方法,其特征在于,所述方法包括:
获取已训练的教师模型对样本图像中各像素的类别预测结果;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;所述信息量表示所述样本图像中像素的不确定性;
在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割;
所述在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数,包括:
在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和所述待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量在对应图像类别下的权重;
根据各预设图像类别下样本图像中各像素的信息量,以及所述样本图像中各像素的信息量在对应图像类别下的权重,得到所述待训练的学生模型的目标损失函数。
2.根据权利要求1所述的方法,其特征在于,所述在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和所述待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量在对应图像类别下的权重,包括:
根据所述样本图像中各像素的类别预测结果,从所述各预设图像类别中确定出所述样本图像中各像素所属的图像类别;
根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息散度;所述信息散度表示所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果之间的距离;
在所述各预设图像类别下,根据所述样本图像中各像素的信息散度,依次确定所述样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量的权重,得到所述样本图像中各像素的信息量在对应图像类别下的权重。
3.根据权利要求2所述的方法,其特征在于,所述根据所述样本图像中各像素的信息量和所述样本图像中各像素的信息量在对应图像类别下的权重,得到所述待训练的学生模型的目标损失函数,包括:
在所述各预设图像类别下,分别根据样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量和所述与所述预设图像类别相同的像素的信息量的权重,确定所述样本图像在所述各预设图像类别下的总信息量;
根据所述样本图像在所述各预设图像类别下的总信息量之和的平均值,得到所述待训练的学生模型的目标损失函数。
4.根据权利要求1所述的方法,其特征在于,所述根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量,包括:
根据所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素在所述各预设图像类别下的类别信息量;
根据所述各像素在所述各预设图像类别下的类别信息量,得到所述各像素的信息量。
5.根据权利要求1所述的方法,其特征在于,所述根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量,还包括:
分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果;
分别根据所述各像素的目标预测结果,得到所述样本图像中各像素的信息量。
6.根据权利要求5所述的方法,其特征在于,所述分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果,包括:
分别从所述各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为所述各像素的目标预测结果。
7.根据权利要求5所述的方法,其特征在于,所述分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果,包括:
分别从所述各像素的类别预测结果中,筛选出类别预测概率最大和类别预测概率第二大的类别预测结果,作为所述各像素的目标预测结果。
8.根据权利要求1所述的方法,其特征在于,所述根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型,包括:
获取所述待训练的学生模型的初始损失函数;
根据所述初始损失函数和所述目标损失函数,得到总损失函数;
根据所述总损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型。
9.根据权利要求8所述的方法,其特征在于,所述根据所述初始损失函数和所述目标损失函数,得到总损失函数,包括:
将所述初始损失函数与所述目标损失函数进行相加,得到所述总损失函数。
10.一种模型训练装置,其特征在于,所述装置包括:
像素预测模块,用于获取已训练的教师模型对样本图像中各像素的类别预测结果;
信息提取模块,用于根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;所述信息量表示所述样本图像中像素的不确定性;
函数获取模块,用于在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;
模型获取模块,用于根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割;
所述函数获取模块,还用于在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和所述待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量在对应图像类别下的权重;根据各预设图像类别下样本图像中各像素的信息量,以及所述样本图像中各像素的信息量在对应图像类别下的权重,得到所述待训练的学生模型的目标损失函数。
11.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至9中任一项所述的方法的步骤。
12.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至9中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210274888.2A CN114359563B (zh) | 2022-03-21 | 2022-03-21 | 模型训练方法、装置、计算机设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210274888.2A CN114359563B (zh) | 2022-03-21 | 2022-03-21 | 模型训练方法、装置、计算机设备和存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114359563A CN114359563A (zh) | 2022-04-15 |
CN114359563B true CN114359563B (zh) | 2022-06-28 |
Family
ID=81094714
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210274888.2A Active CN114359563B (zh) | 2022-03-21 | 2022-03-21 | 模型训练方法、装置、计算机设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114359563B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115115828A (zh) * | 2022-04-29 | 2022-09-27 | 腾讯医疗健康(深圳)有限公司 | 数据处理方法、装置、程序产品、计算机设备和介质 |
CN115690592B (zh) * | 2023-01-05 | 2023-04-25 | 阿里巴巴(中国)有限公司 | 图像处理方法和模型训练方法 |
CN116071608B (zh) * | 2023-03-16 | 2023-06-06 | 浙江啄云智能科技有限公司 | 目标检测方法、装置、设备和存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111639524A (zh) * | 2020-04-20 | 2020-09-08 | 中山大学 | 一种自动驾驶图像语义分割优化方法 |
CN112132197A (zh) * | 2020-09-15 | 2020-12-25 | 腾讯科技(深圳)有限公司 | 模型训练、图像处理方法、装置、计算机设备和存储介质 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN113538441A (zh) * | 2021-01-06 | 2021-10-22 | 腾讯科技(深圳)有限公司 | 图像分割模型的处理方法、图像处理方法及装置 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110147836B (zh) * | 2019-05-13 | 2021-07-02 | 腾讯科技(深圳)有限公司 | 模型训练方法、装置、终端及存储介质 |
-
2022
- 2022-03-21 CN CN202210274888.2A patent/CN114359563B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111639524A (zh) * | 2020-04-20 | 2020-09-08 | 中山大学 | 一种自动驾驶图像语义分割优化方法 |
CN112132197A (zh) * | 2020-09-15 | 2020-12-25 | 腾讯科技(深圳)有限公司 | 模型训练、图像处理方法、装置、计算机设备和存储介质 |
CN113538441A (zh) * | 2021-01-06 | 2021-10-22 | 腾讯科技(深圳)有限公司 | 图像分割模型的处理方法、图像处理方法及装置 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN114359563A (zh) | 2022-04-15 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114359563B (zh) | 模型训练方法、装置、计算机设备和存储介质 | |
CN111815432B (zh) | 金融服务风险预测方法及装置 | |
WO2020155300A1 (zh) | 一种模型预测方法及装置 | |
CN111783712A (zh) | 一种视频处理方法、装置、设备及介质 | |
US20230102640A1 (en) | System and methods for machine learning training data selection | |
CN111639230B (zh) | 一种相似视频的筛选方法、装置、设备和存储介质 | |
CN113343091A (zh) | 面向产业和企业的科技服务推荐计算方法、介质及程序 | |
CN110807693A (zh) | 专辑的推荐方法、装置、设备和存储介质 | |
CN111753729B (zh) | 一种假脸检测方法、装置、电子设备及存储介质 | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
US20220004849A1 (en) | Image processing neural networks with dynamic filter activation | |
CN112465847A (zh) | 一种基于预测清晰边界的边缘检测方法、装置及设备 | |
CN111291795A (zh) | 人群特征分析方法、装置、存储介质和计算机设备 | |
CN108229572B (zh) | 一种参数寻优方法及计算设备 | |
CN113223017A (zh) | 目标分割模型的训练方法、目标分割方法及设备 | |
Wang et al. | Complementary boundary estimation network for temporal action proposal generation | |
CN114283350B (zh) | 视觉模型训练和视频处理方法、装置、设备及存储介质 | |
KR101991043B1 (ko) | 비디오 서머리 방법 | |
CN116756426A (zh) | 项目推荐方法、装置、计算机设备和存储介质 | |
CN116881543A (zh) | 金融资源对象推荐方法、装置、设备、存储介质和产品 | |
CN114168854A (zh) | 信息推荐方法、装置和计算机设备 | |
CN116775989A (zh) | 一种基于时效性辅助任务驱动的个性化论文推荐方法 | |
CN116910604A (zh) | 用户分类方法、装置、计算机设备、存储介质和程序产品 | |
CN118154300A (zh) | 抵质押参数处理方法、装置、计算机设备和存储介质 | |
CN117078427A (zh) | 产品推荐方法、装置、设备、存储介质和程序产品 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |