CN115879514A - 类相关性预测改进方法、装置、计算机设备及存储介质 - Google Patents

类相关性预测改进方法、装置、计算机设备及存储介质 Download PDF

Info

Publication number
CN115879514A
CN115879514A CN202211559865.2A CN202211559865A CN115879514A CN 115879514 A CN115879514 A CN 115879514A CN 202211559865 A CN202211559865 A CN 202211559865A CN 115879514 A CN115879514 A CN 115879514A
Authority
CN
China
Prior art keywords
class
prediction
network
loss
loss function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202211559865.2A
Other languages
English (en)
Other versions
CN115879514B (zh
Inventor
杜杰
王晶
刘鹏
汪天富
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen University
Original Assignee
Shenzhen University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen University filed Critical Shenzhen University
Priority to CN202211559865.2A priority Critical patent/CN115879514B/zh
Publication of CN115879514A publication Critical patent/CN115879514A/zh
Application granted granted Critical
Publication of CN115879514B publication Critical patent/CN115879514B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明实施例公开了类相关性预测改进方法、装置、计算机设备及存储介质。所述方法包括:获取待预测图像,以得到样本集;构建预测网络以及损失函数;利用所述样本集对所述预测网络进行训练,对训练所得的结果进行规范化处理,以得到预测概率值;采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。通过实施本发明实施例的方法可实现在保证准确率的情况下,利用先验知识来拉近相似的两类的网络输出值,拉远不相似的两类的网络输出值,提高神经网络的泛化能力的目的。

Description

类相关性预测改进方法、装置、计算机设备及存储介质
技术领域
本发明涉及图像分类方法,更具体地说是指类相关性预测改进方法、装置、计算机设备及存储介质。
背景技术
图像分类是计算机视觉研究中一个重要的基本问题,是其他视觉任务的基础,如医学图像分类任务在临床治疗中具有重要的辅助作用。然而,如支持向量机方法等传统的图像分类方法在性能上已经达到了极限,并且使用它们需要花费大量的时间和精力来选择和提取分类特征,深度神经网络作为一种新兴的机器学习方法,其在不同分类任务中的潜力已被证明,值得一提的是,卷积神经网络在不同的图像分类任务中取得了最先进的性能。现今,研究人员一直在研究通过更深或更宽的网络结构来提高分类精度的办法,然而深度神经网络存在因参数量大而导致的泛化能力不佳的问题。针对提升网络泛化性能的问题,主流的方法是在神经网络中引入正则化方法。
已有工作提出一种正则化方法来惩罚相似样本之间的预测分布以达到提高神经网络的泛化性能。通过先验知识,可以知道在特征空间上相似的两类会比不相似的两类距离更近,即猫和狗之间的距离会小于猫和卡车之间的距离。而且规范化卷积神经网络的预测值是有效的,因为模型含有最简洁的知识。由此可知,对相似的两类的预测值应该拉近,不相似的两类的预测值应该拉远。该工作已考虑到拉近同类的距离,但未考虑到类间的相关性,即对不相似的两类的预测值应该拉远,因此其对提高分类网络的泛化性能上仍存有局限性;传统的交叉熵损失函数并未考虑到这种预测分布的一致性,也就会导致对某类的预测值较大,而对其他类的预测值较小,无法指出样本之间的相似性与差异性。
因此,有必要设计一种新的方法,实现在保证准确率的情况下,利用先验知识来拉近相似的两类的预测值,拉远不相似的两类的预测值,提高神经网络的泛化能力的目的。
发明内容
本发明的目的在于克服现有技术的缺陷,提供类相关性预测改进方法、装置、计算机设备及存储介质。
为实现上述目的,本发明采用以下技术方案:类相关性预测改进方法,包括:
获取待预测图像,以得到样本集;
构建预测网络以及损失函数;
利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;
采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;
其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
其进一步技术方案为:所述利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值,包括:
将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;
对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
其进一步技术方案为:所述所述对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值,包括:
对所述网络输出值除以预设参数形成softmax层的输入值,并在softmax层计算预测概率值
其进一步技术方案为:所述类内损失函数为class_intra_loss=KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为相似类之间的KL散度;class_intra_loss为类内损失值。
其进一步技术方案为:所述类间损失函数为class_inter_loss=c/KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的不相似类的样本概率值;KL为两类之间的KL散度;class_inter_loss为类间损失值;c是常数。
其进一步技术方案为:所述预测网络为resnet18网络。
本发明还提供了类相关性预测改进装置,包括:
图像获取单元,用于获取待预测图像,以得到样本集;
构建单元,用于构建预测网络以及损失函数;
训练单元,用于利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;
调整单元,用于采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;
其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
其进一步技术方案为:所述训练单元包括:
运算子单元,用于将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;
预测子单元,用于对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
本发明还提供了一种计算机设备,所述计算机设备包括存储器及处理器,所述存储器上存储有计算机程序,所述处理器执行所述计算机程序时实现上述的方法。
本发明还提供了一种存储介质,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时可实现上述的方法。
本发明与现有技术相比的有益效果是:本发明通过在训练预测网络时,在将网络输出经过softmax时会让其除以参数T作为softmax的输入,对网络输出值进行规范化,进行自知识蒸馏,且对于损失函数,设置交叉熵损失函数以及类损失函数,类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数,实现相似类接近,而不相似类远离,实现在保证准确率的情况下,利用先验知识来拉近相似的两类的网络输出值,拉远不相似的两类的网络输出值,提高神经网络的泛化能力的目的。
下面结合附图和具体实施例对本发明作进一步描述。
附图说明
为了更清楚地说明本发明实施例技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的类相关性预测改进方法的应用场景示意图;
图2为本发明实施例提供的类相关性预测改进方法的流程示意图;
图3为本发明实施例提供的类相关性预测改进方法的子流程示意图;
图4为本发明实施例提供的类相关性预测改进装置的示意性框图;
图5为本发明实施例提供的类相关性预测改进装置的训练单元的示意性框图;
图6为本发明实施例提供的计算机设备的示意性框图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在此本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
请参阅图1和图2,图1为本发明实施例提供的类相关性预测改进方法的应用场景示意图。图2为本发明实施例提供的类相关性预测改进方法的示意性流程图。该类相关性预测改进方法应用于服务器中,该服务器与终端进行数据交互,终端输入待预测图像,以对服务器内的预测网络进行训练,训练网络时直接对知识进行正则化,在将网络输出经过softmax时会让其除以参数T作为softmax的输入,对网络输出值进行规范化,相当于对原来的知识进行了蒸馏提炼,损失函数引入了交叉熵损失函数CrossEntroyLoss,CrossEntroyLoss作为多分类的一个损失函数,在分类任务上具有良好的性能;同时,设计了一个类损失函数,类损失函数包括类内损失和类间损失,类内损失是为了拉近相似类的网络输出值产生的损失,同理,类间损失是通过拉远不相似类的预测产生的损失,通过loss的反向传播直至收敛。
图2是本发明实施例提供的类相关性预测改进方法的流程示意图。如图2所示,该方法包括以下步骤S110至S140。
S110、获取待预测图像,以得到样本集。
在本实施例中,样本集是指用于训练预测网络的图像。
S120、构建预测网络以及损失函数。
在本实施例中,所述预测网络为但不局限于resnet18网络。所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
具体地,所述类内损失函数为class_intra_loss=KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为相似类之间的KL散度;class_intra_loss为类内损失值。
所述类间损失函数为class_inter_loss=c/KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的不相似类的样本概率值;KL为两类之间的KL散度;class_inter_loss为类间损失值;c是常数。KL是用来衡量两个类相似与否,这里公式是考虑类间距离,KL是不相似类之间的KL散度。
损失函数由两部分组成,第一部分损失是常用于多分类的CrossEntroyLoss(记为ce_loss),第二部分损失是类损失函数(记为class_loss),class_loss又分为类内损失(class_intra_loss)和类间损失(class_inter_loss),其中class_intra_loss计算的相似类之间的KL(相对熵,Kullback-Leibler divergence)散度,KL散度是衡量两个概率值之间的差异,其值越小代表两个概率越相似,因此通过反向传播,会让相似类的预测概率值不断接近。class_intra_loss的计算公式为class_intra_loss=KL(y_predict(x)||y_predict(x));y_predict(x)为某个类的预测概率值,即某个类的概率值,如一个batch大小的label为猫的预测概率值,y_predict(x)是依据先验知识随机采样的相似类的样本概率值,如随机采样一个batch大小的label为‘狗’样本,y_predict(x)是其预测概率值。同理,class_inter_loss也通过KL散度来计算,公式为class_inter_loss=c/KL(y_predict(x)||y_predict(x)),y_predict(x)为某个类的概率值(如一个batch大小的label为猫的预测概率值),y_predict(x)则是依据先验知识随机采样的不相似类的样本概率值,如随机采样一个batch大小的label为‘卡车’样本,y_predict(x′)是其预测概率值。与class_intra_loss不同的是,不能直接取KL散度的结果,为了拉远类间距离,即两个概率值相差越大效果越好,即class_inter_loss与KL散度是一个反比的关系,在这里取的倒数,class_inter_loss=c/KL(y_predict(x)||y_predict(x))中c是常数,用来提升class_inter_loss值大小。最后,class_loss由上述两个loss组成,即class_loss=class_inter_loss+class_intra_loss。
最后,总的损失函数计算公式如下loss=ce_loss+class_loss。
选用resnet18的原因是其计算速度较快,且其残差块的结构很好地保留了低维空间的特征,在图像分类任务上性能较好,将第一层的卷积核大小由原来的7*7改成了3*3。
S130、利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值。
在本实施例中,预测概率值是指预测概率结果。
在一实施例中,请参阅图3,上述的步骤S130可包括步骤S131~S132。
S131、将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值。
在本实施例中,网络输出值是指样本集对应的图像类别的所有输出值。
S132、对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
具体地,对所述网络输出值除以预设参数形成softmax层的输入值,并在softmax层计算预测概率值。具体地,对所述网络输出值采用
Figure BDA0003984199070000071
转换为每个种类的概率值,
zi和zj是一个图像的网络输出值的其中一个数值,对于一张图像来说,它在每个类别上都会有一个值,就如三分类中网络输出值为[2,3,5],那么这里的zi就是其中的2或者3或者5,pi(也是一个值)就是经过
Figure BDA0003984199070000072
运算得到的概率值,原先的输出结果output就相当于这里的[2,3,5],它是一个图像的网络输出值。
在本实施例中,经过resnet18网络后,得到一个shape为[batch_size,C]大小的output(batch_size是batch的大小,C是预测的类的总数),如
Figure BDA0003984199070000073
所示,其中,zi以及zj是某个样本的网络输出的其中一个值,T是一个常量,用于提升负标签(即模型对其他类的预测)的信息,取2或者4,pi是计算后的一个概率值。通过/>
Figure BDA0003984199070000074
将网络输出值转换为概率值,同时提升了对前几类的预测概率值。
记经过
Figure BDA0003984199070000075
后得到的预测概率值为:/>
Figure BDA0003984199070000076
原始的softmax函数计算公式如下:
Figure BDA0003984199070000077
由于softmax函数引入了指数函数,也就是说在x轴上一个很小的变化会导致y轴上很大的变化,会将输出的数值拉开距离,直接使用softmax层的输出值作为预测概率值的话,会导致对某类的网络输出值较大,而对其他类的网络输出值较小。而加入T这个变量就可以缓解这个问题。/>
Figure BDA0003984199070000078
是加入了T之后的softmax函数,原来的softmax函数/>
Figure BDA0003984199070000079
是T=1的特例。T越高,softmax的计算结果越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,也就提升了对其他类的预测概率值。假设现在有一个网络做的是三分类预测(猫,狗,卡车),其中有个样本的网络输出output是[5,4,1],那么经过softmax后其结果是[0.7214,0.2654,0.0132],而令T=2,经过softmax后其结果是[0.5741,0.3482,0.0777],可以看出加入T后提升了对其他类的预测,更好地保留了负标签的信息,这为拉近相似类,拉远不相似类的网络输出值奠定了基础。
传统的分类网络将网络输出转换成概率的方式是让网络输出直接经过softmax层,与之不同的是,为了提高网络对前几类预测的概率值,在将网络输出经过softmax时会让其除以参数T作为softmax的输入,对网络输出值进行规范化,相当于对原来的知识进行了蒸馏提炼。
S140、采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整。
利用损失函数计算预测概率值的损失值,以损失值确定训练后的预测网络是否收敛,未收敛,则重新调整预测网络后,进行再次训练,此过程属于现有技术,此处不再赘述。
本实施例的方法在进行分类时同时考虑到类内距离、类间距离和类间相关性,在保证准确率的情况下利用先验知识来拉近相似的两类的网络输出值,拉远不相似的两类的网络输出值,由此达到提高神经网络的泛化能力的目的。其次,一些工作使用知识蒸馏将知识从预先训练的复杂教师网络模型迁移到学生网络模型,但本实施例的方法在训练网络时直接对知识进行正则化,称之为自知识蒸馏。传统的分类网络将网络输出转换成概率的方式是让网络输出直接经过softmax层,与之不同的是,为了提高网络对前几类预测的概率值,在将网络输出经过softmax时会让其除以参数T作为softmax的输入,对网络输出值进行规范化,相当于对原来的知识进行了蒸馏提炼。通过改进损失函数的方法来实现拉近同类的网络输出值以及拉远不同类的网络输出值,这比修改网络结构等其他方法更加简便且易于复现。具体来说,损失函数引入了CrossEntroyLoss,CrossEntroyLoss作为多分类的一个损失函数,在分类任务上具有良好的性能;同时,设计了一个类损失函数,类损失函数包括类内损失和类间损失,类内损失是为了拉近相似类的网络输出值产生的损失,同理,类间损失是通过拉远不相似类的预测产生的损失,通过loss的反向传播直至收敛,最后本实施例的方法能达到的效果如表1所示,对于ground-truth为猫的样本,在预测上拉近猫和狗的距离,同时拉远猫和卡车的距离,同理,对于ground-truth为卡车的样本,我们会拉远猫、狗与卡车之间的距离。
表1.考虑类间相关性后的预测概率值
Figure BDA0003984199070000081
Figure BDA0003984199070000091
上述的类相关性预测方法,通过在训练预测网络时,在将网络输出经过softmax时会让其除以参数T作为softmax的输入,对网络输出值进行规范化,进行自知识蒸馏,且对于损失函数,设置交叉熵损失函数以及类损失函数,类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数,实现相似类接近,而不相似类远离,实现在保证准确率的情况下,利用先验知识来拉近相似的两类的网络输出值,拉远不相似的两类的网络输出值,提高神经网络的泛化能力的目的。
图4是本发明实施例提供的一种类相关性预测改进装置300的示意性框图。如图4所示,对应于以上类相关性预测改进方法,本发明还提供一种类相关性预测改进装置300。该类相关性预测改进装置300包括用于执行上述类相关性预测改进方法的单元,该装置可以被配置于服务器中。具体地,请参阅图4,该类相关性预测改进装置300包括图像获取单元301、构建单元302、训练单元303以及调整单元304。
图像获取单元301,用于获取待预测图像,以得到样本集;构建单元302,用于构建预测网络以及损失函数;训练单元303,用于利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;调整单元304,用于采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。所述类内损失函数为class_intra_loss=KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为相似类之间的KL散度;class_intra_loss为类内损失值。所述类间损失函数为class_inter_loss=c/KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为两类之间的KL散度;class_inter_loss为类间损失值;c是常数。所述预测网络为resnet18网络。
在一实施例中,如图5所示,所述训练单元303包括运算子单元3031以及预测子单元3032。
运算子单元3031,用于将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;预测子单元3032,用于对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
在一实施例中,所述输入子单元3032,用于对所述网络输出值除以预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
需要说明的是,所属领域的技术人员可以清楚地了解到,上述类相关性预测改进装置300和各单元的具体实现过程,可以参考前述方法实施例中的相应描述,为了描述的方便和简洁,在此不再赘述。
上述类相关性预测改进装置300可以实现为一种计算机程序的形式,该计算机程序可以在如图6所示的计算机设备上运行。
请参阅图6,图6是本申请实施例提供的一种计算机设备的示意性框图。该计算机设备500可以是服务器,其中,服务器可以是独立的服务器,也可以是多个服务器组成的服务器集群。
参阅图6,该计算机设备500包括通过系统总线501连接的处理器502、存储器和网络接口505,其中,存储器可以包括非易失性存储介质503和内存储器504。
该非易失性存储介质503可存储操作系统5031和计算机程序5032。该计算机程序5032包括程序指令,该程序指令被执行时,可使得处理器502执行一种类相关性预测改进方法。
该处理器502用于提供计算和控制能力,以支撑整个计算机设备500的运行。
该内存储器504为非易失性存储介质503中的计算机程序5032的运行提供环境,该计算机程序5032被处理器502执行时,可使得处理器502执行一种类相关性预测改进方法。
该网络接口505用于与其它设备进行网络通信。本领域技术人员可以理解,图6中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备500的限定,具体的计算机设备500可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
其中,所述处理器502用于运行存储在存储器中的计算机程序5032,以实现如下步骤:
获取待预测图像,以得到样本集;构建预测网络以及损失函数;利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;
其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
所述类内损失函数为class_intra_loss=KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为相似类之间的KL散度;class_intra_loss为类内损失值。
所述类间损失函数为class_inter_loss=c/KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的不相似类的样本概率值;KL为相似类之间的KL散度;class_inter_loss为类间损失值;c是常数。
所述预测网络为resnet18网络。
在一实施例中,处理器502在实现所述利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值步骤时,具体实现如下步骤:
将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
在一实施例中,处理器502在实现所述对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值步骤时,具体实现如下步骤:
对所述网络输出值除以预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
应当理解,在本申请实施例中,处理器502可以是中央处理单元(CentralProcessing Unit,CPU),该处理器502还可以是其他通用处理器、数字信号处理器(DigitalSignal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。其中,通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
本领域普通技术人员可以理解的是实现上述实施例的方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成。该计算机程序包括程序指令,计算机程序可存储于一存储介质中,该存储介质为计算机可读存储介质。该程序指令被该计算机系统中的至少一个处理器执行,以实现上述方法的实施例的流程步骤。
因此,本发明还提供一种存储介质。该存储介质可以为计算机可读存储介质。该存储介质存储有计算机程序,其中该计算机程序被处理器执行时使处理器执行如下步骤:
获取待预测图像,以得到样本集;构建预测网络以及损失函数;利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;
其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
所述类内损失函数为class_intra_loss=KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的不相似类的样本概率值;KL为相似类之间的KL散度;class_intra_loss为类内损失值。
所述类间损失函数为class_inter_loss=c/KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为相似类之间的KL散度;class_inter_loss为类间损失值;c是常数。
所述预测网络为resnet18网络。
在一实施例中,所述处理器在执行所述计算机程序而实现所述利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值步骤时,具体实现如下步骤:
将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
在一实施例中,所述处理器在执行所述计算机程序而实现所述对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值步骤时,具体实现如下步骤:
对所述网络输出值除以预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
所述存储介质可以是U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的计算机可读存储介质。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的。例如,各个单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式。例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
本发明实施例方法中的步骤可以根据实际需要进行顺序调整、合并和删减。本发明实施例装置中的单元可以根据实际需要进行合并、划分和删减。另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。
该集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分,或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,终端,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (10)

1.类相关性预测改进方法,其特征在于,包括:
获取待预测图像,以得到样本集;
构建预测网络以及损失函数;
利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;
采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;
其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
2.根据权利要求1所述的类相关性预测改进方法,其特征在于,所述利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值,包括:
将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;
对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
3.根据权利要求2所述的类相关性预测改进方法,其特征在于,所述对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值,包括:
对所述网络输出值除以预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
4.根据权利要求1所述的类相关性预测改进方法,其特征在于,所述类内损失函数为class_intra_loss=KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的相似类的样本概率值;KL为相似类之间的KL散度;class_intra_loss为类内损失值。
5.根据权利要求1所述的类相关性预测改进方法,其特征在于,所述类间损失函数为class_inter_loss=c/KL(y_predict(x)||y_predict(x));其中,y_predict(x)为某个类的预测概率值;y_predict(x)为依据先验知识随机采样的不相似类的样本概率值;KL为两类之间的KL散度;class_inter_loss为类间损失值;c是常数。
6.根据权利要求1所述的类相关性预测改进方法,其特征在于,所述预测网络为resnet18网络。
7.类相关性预测改进装置,其特征在于,包括:
图像获取单元,用于获取待预测图像,以得到样本集;
构建单元,用于构建预测网络以及损失函数;
训练单元,用于利用所述样本集对所述预测网络进行训练,并对训练所得的结果进行规范化处理,以得到预测概率值;
调整单元,用于采用所述损失函数结合所述预测概率值对训练后的预测网络进行调整;
其中,所述损失函数包括交叉熵损失函数以及类损失函数;所述类损失函数包括用于拉近相似类的网络输出值产生的损失的类内损失函数以及用于拉远不相似类的预测产生的损失的类间损失函数。
8.根据权利要求7所述的类相关性预测改进装置,其特征在于,所述训练单元包括:
运算子单元,用于将所述样本集输入至所述预测网络中进行特征提取,以得到网络输出值;
预测子单元,用于对所述网络输出值结合预设参数形成softmax层的输入值,并在softmax层计算预测概率值。
9.一种计算机设备,其特征在于,所述计算机设备包括存储器及处理器,所述存储器上存储有计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至6中任一项所述的方法。
10.一种存储介质,其特征在于,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时可实现如权利要求1至6中任一项所述的方法。
CN202211559865.2A 2022-12-06 2022-12-06 类相关性预测改进方法、装置、计算机设备及存储介质 Active CN115879514B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211559865.2A CN115879514B (zh) 2022-12-06 2022-12-06 类相关性预测改进方法、装置、计算机设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211559865.2A CN115879514B (zh) 2022-12-06 2022-12-06 类相关性预测改进方法、装置、计算机设备及存储介质

Publications (2)

Publication Number Publication Date
CN115879514A true CN115879514A (zh) 2023-03-31
CN115879514B CN115879514B (zh) 2023-08-04

Family

ID=85766200

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211559865.2A Active CN115879514B (zh) 2022-12-06 2022-12-06 类相关性预测改进方法、装置、计算机设备及存储介质

Country Status (1)

Country Link
CN (1) CN115879514B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116226755A (zh) * 2023-05-10 2023-06-06 广东维信智联科技有限公司 一种基于大数据的实时数据识别方法

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109002845A (zh) * 2018-06-29 2018-12-14 西安交通大学 基于深度卷积神经网络的细粒度图像分类方法
CN111538823A (zh) * 2020-04-26 2020-08-14 支付宝(杭州)信息技术有限公司 信息处理方法、模型训练方法、装置、设备及介质
US10769766B1 (en) * 2018-05-31 2020-09-08 Amazon Technologies, Inc. Regularized multi-label classification from partially labeled training data
CN112614571A (zh) * 2020-12-24 2021-04-06 中国科学院深圳先进技术研究院 神经网络模型的训练方法、装置、图像分类方法和介质
CN113850179A (zh) * 2020-10-27 2021-12-28 深圳市商汤科技有限公司 图像检测方法及相关模型的训练方法、装置、设备、介质
US20220051017A1 (en) * 2020-08-11 2022-02-17 Nvidia Corporation Enhanced object identification using one or more neural networks
CN115240011A (zh) * 2022-08-11 2022-10-25 昂坤视觉(北京)科技有限公司 图像分类方法、装置、计算机可读存储介质及计算机设备

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10769766B1 (en) * 2018-05-31 2020-09-08 Amazon Technologies, Inc. Regularized multi-label classification from partially labeled training data
CN109002845A (zh) * 2018-06-29 2018-12-14 西安交通大学 基于深度卷积神经网络的细粒度图像分类方法
CN111538823A (zh) * 2020-04-26 2020-08-14 支付宝(杭州)信息技术有限公司 信息处理方法、模型训练方法、装置、设备及介质
US20220051017A1 (en) * 2020-08-11 2022-02-17 Nvidia Corporation Enhanced object identification using one or more neural networks
CN113850179A (zh) * 2020-10-27 2021-12-28 深圳市商汤科技有限公司 图像检测方法及相关模型的训练方法、装置、设备、介质
CN112614571A (zh) * 2020-12-24 2021-04-06 中国科学院深圳先进技术研究院 神经网络模型的训练方法、装置、图像分类方法和介质
CN115240011A (zh) * 2022-08-11 2022-10-25 昂坤视觉(北京)科技有限公司 图像分类方法、装置、计算机可读存储介质及计算机设备

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116226755A (zh) * 2023-05-10 2023-06-06 广东维信智联科技有限公司 一种基于大数据的实时数据识别方法

Also Published As

Publication number Publication date
CN115879514B (zh) 2023-08-04

Similar Documents

Publication Publication Date Title
US11410038B2 (en) Frame selection based on a trained neural network
CN111553480B (zh) 图像数据处理方法、装置、计算机可读介质及电子设备
CN110309874B (zh) 负样本筛选模型训练方法、数据筛选方法和数据匹配方法
EP4080416A1 (en) Adaptive search method and apparatus for neural network
CN110717522A (zh) 图像分类网络的对抗防御方法及相关装置
US11755880B2 (en) Method and apparatus for optimizing and applying multilayer neural network model, and storage medium
US8761496B2 (en) Image processing apparatus for calculating a degree of similarity between images, method of image processing, processing apparatus for calculating a degree of approximation between data sets, method of processing, computer program product, and computer readable medium
WO2021088365A1 (zh) 确定神经网络的方法和装置
JP2008542911A (ja) メトリック埋め込みによる画像比較
CN110929836B (zh) 神经网络训练及图像处理方法和装置、电子设备、介质
CN109766476B (zh) 视频内容情感分析方法、装置、计算机设备及存储介质
CN115879514A (zh) 类相关性预测改进方法、装置、计算机设备及存储介质
CN113344016A (zh) 深度迁移学习方法、装置、电子设备及存储介质
CN114581868A (zh) 基于模型通道剪枝的图像分析方法和装置
Lugmayr et al. Normalizing flow as a flexible fidelity objective for photo-realistic super-resolution
CN112085175B (zh) 基于神经网络计算的数据处理方法和装置
CN117726602A (zh) 基于带状池化的息肉分割方法及系统
CN111262873B (zh) 一种基于小波分解的用户登录特征预测方法及其装置
CN108629381A (zh) 基于大数据的人群筛选方法及终端设备
CN111640438A (zh) 音频数据处理方法、装置、存储介质及电子设备
CN113448680A (zh) 价值模型的训练方法以及相关设备
CN115795355A (zh) 一种分类模型训练方法、装置及设备
CN110929731A (zh) 一种基于探路者智能搜索算法的医疗影像处理方法及装置
CN111785379B (zh) 脑功能连接预测方法、装置、计算机设备及存储介质
CN110533192B (zh) 强化学习方法、装置、计算机可读介质及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant