CN114359809A - 分类及分类模型的训练方法、装置、设备及介质 - Google Patents

分类及分类模型的训练方法、装置、设备及介质 Download PDF

Info

Publication number
CN114359809A
CN114359809A CN202210021106.4A CN202210021106A CN114359809A CN 114359809 A CN114359809 A CN 114359809A CN 202210021106 A CN202210021106 A CN 202210021106A CN 114359809 A CN114359809 A CN 114359809A
Authority
CN
China
Prior art keywords
classification
category
model
classified
class
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210021106.4A
Other languages
English (en)
Inventor
申世伟
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Dajia Internet Information Technology Co Ltd
Original Assignee
Beijing Dajia Internet Information Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Dajia Internet Information Technology Co Ltd filed Critical Beijing Dajia Internet Information Technology Co Ltd
Priority to CN202210021106.4A priority Critical patent/CN114359809A/zh
Publication of CN114359809A publication Critical patent/CN114359809A/zh
Pending legal-status Critical Current

Links

Images

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

本公开关于一种分类及分类模型的训练方法、装置、设备及介质,所述分类方法包括:获取待分类对象;基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别和中间特征;基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到第二类别;将所述第二类别确定为所述待分类对象的类别。根据本公开的分类及分类模型的训练方法、装置、设备及介质,在传统的分类模型基础上增加校验分支模型,从而实现对分类分支模型的分类结果进行二次校验,可进一步提高分类分支模型的准召率,提高分类的准确性。

Description

分类及分类模型的训练方法、装置、设备及介质
技术领域
本公开涉及人工智能领域,更具体地说,涉及一种分类及分类模型的训练方法、装置、设备及介质。
背景技术
目前,随着人工智能技术的高速发展,越来越多的应用场景通过调用分类模型对目标对象进行分类,分类模型的分类准确性直接影响针对目标对象的后续执行动作,例如,在视频精准推送或者视频内容审核场景中,视频分类的准确度可直接影响推送效果或者审核效率,因此,需不断提高分类模型的分类准确性,以提高在各种场景中针对目标对象的分类准确性。
发明内容
本公开提供一种分类及分类模型的训练方法、装置、设备及介质,以至少解决上述相关技术中的问题。
根据本公开实施例的第一方面,提供一种分类方法,包括:获取待分类对象;基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别和中间特征,所述第一类别为通过所述分类分支模型估计的所述待分类对象的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到所述待分类对象的第二类别,所述第二类别为通过所述校验分支模型对所述第一类别进行校验之后得到的至少一个类别;将所述第二类别确定为所述待分类对象的类别。
可选地,所述校验分支模型包括类别特征获取模块、特征交互模块和分类模块,所述基于所述第一类别以及所述中间特征,利用分类模型中的校验分支模型,得到所述待分类对象的第二类别,包括:基于所述第一类别,利用所述类别特征获取模块,得到所述第一类别的类别特征;基于所述类别特征以及所述中间特征,利用所述特征交互模块,得到配置权重的中间特征,所述权重表示所述类别特征与所述中间特征之间的相关程度;基于所述配置权重的中间特征,利用所述分类模块,得到所述待分类对象的第二类别。
可选地,所述基于所述第一类别,利用所述类别特征获取模块,得到所述第一类别的类别特征,包括:对所述第一类别进行单热编码,得到所述第一类别的类别列表;根据所述类别列表从字典矩阵中查询得到所述类别特征,所述字典矩阵是维度空间中的向量表示,所述字典矩阵的行数与所述稀疏矩阵的列数相同,所述字典矩阵的列数与所述中间特征的维度相同。
可选地,所述基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别,包括:基于所述待分类对象,利用所述分类分支模型,获取所述待分类对象的至少一个第一类别概率;当所述至少一个第一类别概率大于或者等于所对应的第一阈值时,将所述至少一个第一类别概率对应的至少一个类别确定为所述待分类对象的第一类别。
可选地,所述基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到所述待分类对象的第二类别,包括:基于所述第一类别以及所述中间特征,利用所述校验分支模型,获取所述待分类对象的至少一个第二类别概率;当所述至少一个第二类别概率大于或者等于所对应的第二阈值时,将所述至少一个第二类别概率对应的至少一个类别确定为所述待分类对象的第二类别。
根据本公开实施例的第二方面,提供一种分类模型的训练方法,包括:获取待分类对象样本,所述待分类对象样本对应有真实分类标签;基于所述待分类对象样本,通过所述分类模型中的分类分支模型,得到所述待分类对象样本的第一估计结果和中间特征,所述第一估计结果表示通过所述分类分支模型估计的所述待分类对象样本的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;基于所述真实分类标签以及所述中间特征,通过所述分类模型中的校验分支模型,得到所述待分类对象样本的第二估计结果,所述第二估计结果表示通过所述校验分支模型估计的所述待分类对象样本的至少一个类别;根据所述第一估计结果、所述第二估计结果以及所述真实分类标签计算损失;通过根据所述损失调整所述分类分支模型和校验分支模型的模型参数,对所述分类模型进行训练。
可选地,所述校验分支模型包括类别特征获取模块、特征交互模块和分类模块,所述基于所述真实分类标签以及所述中间特征,通过所述校验分支模型,得到所述待分类对象样本的第二估计结果,包括:基于所述真实分类标签,利用所述类别特征获取模块,获取所述真实分类标签的类别特征;基于所述类别特征以及所述中间特征,利用所述特征交互模块,得到配置权重的中间特征;基于所述配置权重的中间特征,利用所述分类模块,得到所述待分类对象样本的第二估计结果。
可选地,所述基于所述真实分类标签,利用所述类别特征获取模块,获取所述真实分类标签的类别特征,包括:对所述真实分类标签进行单热编码,得到所述真实分类标签的类别列表;根据所述类别列表从字典矩阵中查询得到所述类别特征,所述字典矩阵是维度空间中的向量表示,所述字典矩阵的行数与所述类别列表的列数相同,所述字典矩阵的列数与所述中间特征的维度相同。
可选地,所述根据所述第一估计结果、所述第二估计结果以及所述真实分类标签计算损失,包括:根据所述第一估计结果和所述真实分类标签,计算第一损失;根据所述第二估计结果和所述真实分类标签,计算第二损失;根据所述第一损失和所述第二损失,得到所述损失。
可选地,根据所述第一损失和所述第二损失,得到所述损失,包括:将所述第一损失和所述第二损失的加权和作为所述损失,其中,所述第一损失的权重大于所述第二损失的权重。
根据本公开实施例的第三方面,提供一种分类装置,包括:对象获取单元,被配置为:获取待分类对象;第一类别获取单元,被配置为:基于所述待分类对象,利用分类模型中的分类分支模型,得到第一类别和中间特征,所述第一类别为通过所述分类分支模型估计的所述待分类对象的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;第二类别获取单元,被配置为:基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到所述待分类对象的第二类别,所述第二类别为通过所述校验分支模型对所述第一类别进行校验之后得到的至少一个类别;类别确定单元,被配置为:将所述第二类别确定为所述待分类对象的类别。
可选地,所述校验分支模型包括类别特征获取模块、特征交互模块和分类模块,所述第二类别获取单元被配置为:基于所述第一类别,利用所述类别特征获取模块,得到所述第一类别的类别特征;基于所述类别特征以及所述中间特征,利用所述特征交互模块,得到配置权重的中间特征,所述权重表示所述类别特征与所述中间特征之间的相关程度;基于所述配置权重的中间特征,利用所述分类模块,得到所述待分类对象的第二类别。
可选地,所述第二类别获取单元被配置为:对所述第一类别进行单热编码,得到所述第一类别的类别列表;根据所述类别列表从字典矩阵中查询得到所述类别特征,所述字典矩阵是维度空间中的向量表示,所述字典矩阵的行数与所述类别列表的列数相同,所述字典矩阵的列数与所述中间特征的维度相同。
可选地,所述第一类别获取单元被配置为:基于所述待分类对象,利用所述分类分支模型,获取所述待分类对象的至少一个第一类别概率;当所述至少一个第一类别概率大于或者等于所对应的第一阈值时,将所述至少一个第一类别概率对应的至少一个类别确定为所述待分类对象的第一类别。
可选地,所述第二类别获取单元被配置为:基于所述第一类别以及所述中间特征,利用所述校验分支模型,获取所述待分类对象的至少一个第二类别概率;当所述至少一个第二类别概率大于或者等于所对应的第二阈值时,将所述至少一个第二类别概率对应的至少一个类别确定为所述待分类对象的第二类别。
根据本公开实施例的第四方面,提供一种分类模型的训练装置,包括:样本获取单元,被配置为:获取待分类对象样本,所述待分类对象样本对应有真实分类标签;第一估计单元,被配置为:基于所述待分类对象样本,通过所述分类模型中的分类分支模型,得到所述待分类对象样本的第一估计结果和中间特征,所述第一估计结果表示通过所述分类分支模型估计的所述待分类对象样本的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;第二估计单元,被配置为:基于所述真实分类标签以及所述中间特征,通过所述分类模型中的校验分支模型,得到所述待分类对象样本的第二估计结果,所述第二估计结果表示通过所述校验分支模型估计的所述待分类对象样本的至少一个类别;损失计算单元,被配置为:根据所述第一估计结果、所述第二估计结果以及所述真实分类标签计算损失;模型训练单元,被配置为:通过根据所述损失调整所述分类分支模型和校验分支模型的模型参数,对所述分类模型进行训练。
可选地,所述校验分支模型包括类别特征获取模块、特征交互模块和分类模块,所述第二估计单元被配置为:基于所述真实分类标签,利用所述类别特征获取模块,获取所述真实分类标签的类别特征;基于所述类别特征以及所述中间特征,利用所述特征交互模块,得到配置权重的中间特征;基于所述配置权重的中间特征,利用所述分类模块,得到所述待分类对象样本的第二估计结果。
可选地,所述第二估计单元被配置为:对所述真实分类标签进行单热编码,得到所述真实分类标签的类别列表;根据所述类别列表从字典矩阵中查询得到所述类别特征,所述字典矩阵是维度空间中的向量表示,所述字典矩阵的行数与所述稀疏矩阵的列数相同,所述字典矩阵的列数与所述中间特征的维度相同。
可选地,所述损失计算单元被配置为:根据所述第一估计结果和所述真实分类标签,计算第一损失;根据所述第二估计结果和所述真实分类标签,计算第二损失;根据所述第一损失和所述第二损失,得到所述损失。
可选地,所述模型训练单元被配置为:将所述第一损失和所述第二损失的加权和作为所述损失,其中,所述第一损失的权重大于所述第二损失的权重。
根据本公开实施例的第五方面,提供一种电子设备,包括:至少一个处理器;至少一个存储计算机可执行指令的存储器,其中,所述计算机可执行指令在被所述至少一个处理器运行时,促使所述至少一个处理器执行根据本公开的第一方面的分类方法或第二方面的分类模型的训练方法。
根据本公开实施例的第六方面,提供一种存储指令的计算机可读存储介质,当所述指令被至少一个处理器运行时,促使所述至少一个处理器执行根据本公开的第一方面的分类方法或第二方面的分类模型的训练方法。
根据本公开实施例的第七方面,提供一种计算机程序产品,该计算机程序产品中的指令可由计算机设备的处理器执行以完成根据本公开的第一方面的分类方法或第二方面的分类模型的训练方法。
本公开的实施例提供的技术方案至少带来以下有益效果:
根据本公开的分类及分类模型的训练方法、装置、设备及介质,在传统的分类模型基础上增加校验分支模型,从而实现对分类分支模型的分类结果进行二次校验,可进一步提高分类分支模型的准召率,提高分类的准确性。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理,并不构成对本公开的不当限定。
图1是示出根据本公开的示例性实施例的分类模型的结构示意图。
图2是示出根据本公开的示例性实施例的分类方法的流程图。
图3是示出根据本公开的示例性实施例的分类分支模型的结构示意图。
图4是示出根据本公开的示例性实施例的分类模型的训练方法的流程图。
图5是示出根据本公开的示例性实施例的分类模型的训练以及运用的整体示意图。
图6是示出根据本公开的示例性实施例的分类装置的框图。
图7是示出根据本公开的示例性实施例的分类模型的训练装置的框图。
图8是示出根据本公开的示例性实施例的电子设备800的框图。
具体实施方式
为了使本领域普通人员更好地理解本公开的技术方案,下面将结合附图,对本公开实施例中的技术方案进行清楚、完整地描述。
需要说明的是,本公开的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本公开的实施例能够以除了在这里图示或描述的那些以外的顺序实施。以下实施例中所描述的实施方式并不代表与本公开相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本公开的一些方面相一致的装置和方法的例子。
在此需要说明的是,在本公开中出现的“若干项之中的至少一项”均表示包含“该若干项中的任意一项”、“该若干项中的任意多项的组合”、“该若干项的全体”这三类并列的情况。例如“包括A和B之中的至少一个”即包括如下三种并列的情况:(1)包括A;(2)包括B;(3)包括A和B。又例如“执行步骤一和步骤二之中的至少一个”,即表示如下三种并列的情况:(1)执行步骤一;(2)执行步骤二;(3)执行步骤一和步骤二。
目前,随着人工智能技术的高速发展,越来越多的应用场景通过调用分类模型对目标对象进行分类,分类模型的分类准确性直接影响针对目标对象的后续执行动作,例如,在视频精准推送或者视频内容审核场景中,视频分类的准确度可直接影响推送效果或者审核效率,因此,需不断提高分类模型的分类准确性。
相关技术中存在一种图像多分类模型:原始图像经多层卷积神经网络提取图像特征之后,接多层FC(Fully Connected Layer,全连接层)进行特征压缩和整合,最后接类别数为C的FC层实现图像的多分类。但相关技术中的图像多分类模型存在分类不准确的现象,导致应用于具体的场景时的分类效果不够理想。
为了提高分类的准确性,本公开提出了分类及分类模型的训练方法、装置、设备及介质,在传统的分类模型基础上增加校验分支模型,从而实现对分类分支模型的分类结果进行二次校验,可进一步提高分类分支模型的准召率,提高分类的准确性。下面,将参照图1至图8具体描述根据本公开的示例性实施例的分类方法及装置和分类模型的训练方法及装置。
图1是示出根据本公开的示例性实施例的分类模型的结构示意图。
这里,分类模型可用于对图像进行分类,例如,可对视频、图片以及文本等进行分类,在此不作限制,而分类可以是多标签分类,例如,可以是针对视频的多标签分类。参照图1,根据本公开的分类模型100包括分类分支模型101和校验分支模型102,其中,将待分类对象(视频或者图片等)进行相关预处理之后(或者也可不进行预处理)输入分类分支模型101,可得到该待分类对象的估计类别,相关技术中直接将该估计类别作为最终的分类结果,但该估计类别与待分类对象的真实类别通常存在偏差(即,存在分类不准确的情况),因此,本公开设置通过校验分支模型102对该估计类别进行验证,可去除该估计类别中准确度偏低(例如,预测概率小于校验分支模型102的预设阈值)的类别,而将准确度偏高(例如,预测概率大于或者等于校验分支模型102的预设阈值)的类别输出,作为最终的分类结果。
由于本公开的分类模型包括分类分支模型和校验分支模型,可实现通过校验分支模型对分类分支模型输出的估计类别进行二次校验,提高了分类模型的分类准确性,进而可提高在各种应用场景中针对目标对象进行分类的准确性。
图2是示出根据本公开的示例性实施例的分类方法的流程图。这里,分类方法基于前述的分类模型执行,分类模型包括分类分支模型和校验分支模型。
参照图2,在步骤201,可获取待分类对象。这里,待分类对象可以是视频、图片或者文本等,对此不作限制,分类是指对待分类对象进行多标签分类,即,待分类对象的类别可能有多个,例如,一张主题为航海的图片的类别可能为大海和船只。可针对视频进行多标签分类,在该场景下,可获取待分类的视频。
在步骤202,可基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别和中间特征。这里,第一类别为通过所述分类分支模型估计的所述待分类对象的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征。
根据本公开的示例性实施例,可基于待分类对象,利用分类分支模型,获取待分类对象的至少一个第一类别概率,当至少一个第一类别概率大于或者等于所对应的第一阈值时,将至少一个第一类别概率对应的至少一个类别确定为待分类对象的第一类别,这里,至少一个第一类别所对应的第一阈值相同或者相异。具体来讲,分类分支模型包括多个输出通道,每个输出通道对应一种类别,每个输出通道包括经分类分支模型预测的该通道所对应的类别的多个预测概率,将其中的概率最大值作为该输出通道的预测概率,当预测概率大于或者等于该输出通道的第一阈值时,可判定该输出通道对应的类别为待分类对象的类别。分类分支模型可以是相关技术中的任一种分类模型,例如,在针对视频进行多标签分类的场景中,分类分支模型结构可如图3所示。图3是示出根据本公开的示例性实施例的分类分支模型的结构示意图。参照图3,首先从待分类的视频中采样M帧全局图像,M为大于等于1的自然数,若无法采样到M帧图像,则将视频的封面帧图像补入采样图像,这里,采样的帧数与分类分支模型的结构相关;然后,将M帧图像输入图像分类网络,提取图像特征,这里,图像分类网络可采用Inception-V3、ResNet-50d或EfficientNet-B3等,对此不作限制,在图3中,示例性采用EfficientNet-B3。假设特征维度为D维,则经过EfficientNet-B3提取的图像特征维度为M*D维;接下来,将M*D维的特征输入transformer结构进行特征交互,得到D维的内容特征(即中间特征),最后,将该D维的内容特征经过多层FC,可得到视频的至少一个类别(即,第一类别)。
在步骤203中,可基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到第二类别。这里,第二类别为通过所述校验分支模型对所述第一类别进行校验之后得到的至少一个类别。
根据本公开的示例性实施例,校验分支模型可包括类别特征获取模块、特征交互模块和分类模块,首先,可基于第一类别,利用类别特征获取模块,得到第一类别的类别特征。在一些实施例中,类别特征获取模块可对第一类别进行单热编码,得到第一类别的类别列表;并根据该类别列表从字典矩阵中查询得到类别特征,其中,字典矩阵是维度空间中的向量表示,字典矩阵的行数与稀疏矩阵的列数相同,字典矩阵的列数与通过分类分支模型的中间层得到的中间特征的维度相同,这里,通过对第一类别进行单热编码并通过编码得到的类别列表查询训练好的字典矩阵,可得到在维度空间中被准确表达的第一类别,从而可提高校验分支模型的校验准确性。具体来讲,可首先对第一类别进行one-hot编码,例如,可编码为(0,1,2,……,T-1),共T个类别,然后获取字典矩阵(T,D)对应位置的embedding(特征向量),例如,如果编码为0,则取字典矩阵的第一行embedding,得到长度为D的数组,以此类推,可得到第一类别中的所有类别的类别特征,这里,字典矩阵经过训练之后,能在极大程度上在维度空间中对第一类别。在得到第一类别的类别特征之后,可基于该类别特征以及在步骤202中通过分类分支模型的中间层得到的中间特征,利用特征交互模块,得到配置权重的中间特征,这里,权重表示该类别特征与中间特征之间的相关程度。具体来讲,可通过multi-head attention(多头注意力)机制进行特征交互,将类别特征作为Attention(Q,K,V)中的Q(Query),中间特征作为K(Key)和V(Value),获取类别特征与中间特征之间的相似度权重,并将该权重叠加至中间特征,得到配置权重的中间特征,即,可实现不同的类别特征激活不同的中间特征,如果第一类别中存在判定不准确的类别,则其与中间特征之间的相似度偏低,相应的权重也偏低,反之,相应的权重偏高;最后,可基于配置权重的中间特征,利用分类模块,得到第二类别。具体来讲,可将配置权重的中间特征经过多层FC,得到第二类别,这里,由于中间特征配置有权重,因此权重偏低的中间特征经过多层FC之后被判定为某个类别的概率亦对应降低,如果低于该某个类别的阈值概率,则最终不会输出该类别,从而达到校验的目的,提高了分类的准确性。
根据本公开的示例性实施例,可基于第一类别以及中间特征,利用校验分支模型,获取待分类对象的至少一个第二类别概率,当至少一个第二类别概率大于或者等于所对应的第二阈值时,将至少一个第二类别概率对应的至少一个类别确定为待分类对象的第二类别,这里,至少一个第二类别所对应的第二阈值相同或者相异。具体来讲,校验分支模型包括多个输出通道,每个输出通道对应一种类别,每个输出通道包括经校验分支模型预测的该通道所对应的类别的多个预测概率,将其中的概率最大值作为该输出通道的预测概率,当预测概率大于或者等于该输出通道的第二阈值时,可判定该输出通道对应的类别确实为待分类对象的类别。
在步骤204,可将所述第二类别确定为所述待分类对象的类别。这里,第二类别是经过校验分支模型校验之后的类别,因此准确性更高。
图4是示出根据本公开的示例性实施例的分类模型的训练方法的流程图。这里,分类模型是指待训练的用于对待分类对象进行分类的模型,包括分类分支模型和校验分支模型。
参照图4,在步骤401,可获取待分类对象样本,所述待分类对象样本对应有真实分类标签。这里,待分类对象样本是指用于训练分类模型的对象样本,例如,视频样本、图片样本或者文本样本等,可根据分类模型的具体应用场景获取对应的待分类对象样本,对此不作限制。待分类对象样本的数量可为一个,也可为多个,示例性地,为保证分类模型的训练效果,待分类对象样本的数量可为多个。待分类对象样本对应有真实分类标签,真实分类标签用于指示待分类对象的真实类别,该真实类别用于描述待分类对象中实际包含的内容,这里,待分类对象的真实标签为多个,例如,一张主题为航海的图片的类别可能为大海和船只。在针对视频多标签分类的场景中,可获取视频样本,视频样本的数量可为一个,也可为多个,示例性地,为保证模型训练效果,视频样本的数量可为多个。
在步骤402,可基于待分类对象样本,通过分类模型中的分类分支模型,得到待分类对象样本的第一估计结果和中间特征,第一估计结果表示通过分类分支模型估计的待分类对象样本的至少一个类别,中间特征为通过分类分支模型的中间层输出的特征。这里,分类分支模型可以是相关技术中的任一种分类模型,例如,在针对视频进行多标签分类的场景中,分类分支模型结构可如图3所示。参照图3,首先从视频样本中采样M帧全局图像,M为大于等于1的自然数,若无法采样到M帧图像,则将视频的封面帧图像补入采样图像,这里,采样的帧数与分类分支模型的结构相关;然后,将M帧图像输入图像分类网络,以提取图像特征,这里,图像分类网络可采用Inception-V3、ResNet-50d或EfficientNet-B3等,对此不作限制,在图3中,示例性采用EfficientNet-B3。假设特征维度为D维,则经过EfficientNet-B3提取的图像特征维度为M*D维;接下来,将M*D维的特征输入transformer结构进行特征交互,得到D维的内容特征(即,中间特征),最后,将该D维的内容特征经过多层FC,可得到估计的视频样本的类别(即,第一估计结果)。
在步骤403,可基于所述真实分类标签以及所述中间特征,通过分类模型中的校验分支模型,得到待分类对象样本的第二估计结果,第二估计结果表示通过校验分支模型估计的待分类对象样本的类别。
根据本公开的示例性实施例,校验分支模型可包括类别特征获取模块、特征交互模块和分类模块,首先,可基于待分类对象的真实分类标签,利用类别特征获取模块,得到第一类别的类别特征。在一些实施例中,类别特征获取模块可对真实分类标签进行单热编码,得到真实类别标签的类别列表,并可根据该类别列表以及字典矩阵,得到真实类别标签的类别特征,其中,字典矩阵是维度空间中的向量表示,字典矩阵是随机初始化的一个矩阵,该字典矩阵的行数与类别列表的列数相同,列数与通过分类分支模型的中间层得到的中间特征的维度相同。具体来讲,可首先对真实分类标签进行one-hot编码,例如,可编码为(0,1,2,……,T-1),共T个真实类别,然后获取字典矩阵(T,D)对应位置的embedding(特征向量),例如,如果编码为0,则取字典矩阵的第一行embedding,得到长度为D的数组,以此类推,可得到真实类别标签中的所有类别的类别特征。在得到真实分类标签的类别特征之后,可基于该类别特征以及通过分类分支模型的中间层得到的中间特征,利用特征交互模块,得到配置权重的中间特征,这里,权重表示该类别特征与中间特征之间的相关程度。具体来讲,可通过multi-head attention(多头注意力)机制进行特征交互,将类别特征作为Attention(Q,K,V)中的Q(Query),中间特征作为K(Key)和V(Value),获取类别特征与中间特征之间的相似度权重,并将该权重叠加至中间特征,得到配置权重的中间特征,将配置权重的中间特征经过多层FC,可得到估计的待分类对象的类别。具体来讲,由于字典矩阵是随机初始化得到的,在训练的初始阶段可能不能很好地表示真实分类标签的特征,因此尽管校验分支模型的输入为真实的分类标签,但在特征交互模块之中,分类特征与中间特征之间的相关程度在训练的初始阶段可能不高,通过反复训练调整相关参数,可获得能很好地反映真实分类标签特征的字典矩阵。
在步骤404,可根据第一估计结果、第二估计结果以及真实分类标签计算损失。这里,损失用以衡量分类模型的估计的待分类对象的类别与真实分类标签之间的差距,可通过损失函数计算该损失,而损失函数,例如,但不限于,可以是交叉熵损失函数。根据本公开的示例性实施例,可首先根据第一估计结果和真实分类标签计算第一损失,然后根据第二估计结果和真实分类标签计算第二损失,最后根据第一损失和第二损失,得到分类模型的总的损失,具体来讲,可将第一损失和第二损失的加权和作为该损失,其中,第一损失的权重大于第二损失的权重。这里,将第一损失和第二损失的加权和作为分类模型的损失对分类模型进行训练,可兼顾分类分支模型的分类效果和校验分支模型的校验效果,并且,使第一损失的权重大于第二损失的权重,可保证分类分支模型的主要分类效果以及校验分支模型的辅助分类效果。示例性地,总的损失与第一损失和第二损失之间的关系可被表示为:
Loss=a*Lossc+b*Losst
其中,Loss表示分类模型的总的损失;Lossc表示第一损失;Losst表示第二损失;a表示第一损失的权重;b表示第二损失的权重。示例性地,a=1;b=0.4,a和b的取值可在a>b的前提下动态调整。
在步骤405,通过根据所述损失调整所述分类分支模型和校验分支模型的模型参数,训练所述分类模型。也就是说,可通过损失函数计算的损失反向传播来调整分类模型的参数。此外,在模型的训练过程中,可使用批量的待分类对象样本(例如,在针对视频多标签分类场景中,可使用批量视频)来调整或更新模型参数,并以最小化损失函数的值为目标,迭代地调整或更新模型参数,直至损失不再下降或达到迭代次数。
图5是示出根据本公开的示例性实施例的分类模型的训练以及运用的整体示意图。
参照图5(a),分类模型包括分类分支模型和校验分支模型,训练时,将待分类对象样本输入分类分支模型,得到中间特征(由中间层对待分类对象样本进行特征提取得到)和第一估计结果,通过第一估计结果和待分类对象样本对应的真实分类标签计算第一损失;将真实分类标签和中间特征输入校验分支模型,得到第二估计结果,通过第二估计结果和待分类对象样本对应的真实分类标签计算第二损失,通过第一损失的值调整分类分支模型的模型参数,并通过第二损失的值调整校验分支模型的模型参数,来训练得到分类模型。参照图5(b),可将待分类对象输入分类分支模型,得到预测的第一类别,将预测的第一类别以及通过分类分支模型得到的中间特征输入校验分支模型,得到对第一类别进行校验之后的第二类别。这里,由于在传统的分类模型基础上增加校验分支模型,从而实现对分类分支模型的分类结果进行二次校验,可进一步提高分类分支模型的准召率,提高了分类的准确性。
图6是示出根据本公开的示例性实施例的分类装置的框图。
参照图6,根据本公开的示例性实施例的分类装置600基于分类模型执行操作,分类装置600可包括对象获取单元601、第一类别获取单元602、第二类别获取单元603和类别确定单元604。
对象获取单元601可获取待分类对象。第一类别获取单元602可基于待分类对象,利用分类模型中的分类分支模型,得到待分类对象的第一类别和中间特征,这里,第一类别为通过分类分支模型估计的待分类对象的至少一个类别,中间特征为通过分类分支模型的中间层输出的特征。第二类别获取单元602可基于第一类别以及中间特征,利用分类模型中校验分支模型,得到第二类别,这里第二类别为通过校验分支模型对第一类别进行校验之后得到的至少一个类别。类别确定单元604可将第二类别确定为待分类对象的类别。
由于图2所示的分类方法可由图6所示的分类装置600来执行,并且对象获取单元601、第一类别获取单元602、第二类别获取单元603和类别确定单元604可分别执行与图2中的步骤201、步骤202、步骤203和步骤204对应的操作,因此,关于图6中的各单元所执行的操作中涉及的任何相关细节均可参见关于图2的相应描述,这里都不再赘述。
此外,需要说明的是,尽管以上在描述分类装置600时将其划分为用于分别执行相应处理的单元,然而,本领域技术人员清楚的是,上述各单元执行的处理也可以在分类装置600不进行任何具体单元划分或者各单元之间并无明确划界的情况下执行。此外,分类装置600还可包括其他单元,例如,数据处理单元、存储单元等。
图7是示出根据本公开的示例性实施例的分类模型的训练装置的框图。这里,分类模型包括分类分支模型和校验分支模型。
参照图7,根据本公开的示例性实施例的分类模型的训练装置700可包括样本获取单元701、第一估计单元702、第二估计单元703、损失计算单元704和模型训练单元705。
样本获取单元701可获取待分类对象样本,该待分类对象样本对应有真实分类标签。第一估计单元702可基于待分类对象样本,通过分类模型中的分类分支模型,得到待分类对象样本的第一估计结果和中间特征,这里,第一估计结果表示通过分类分支模型估计的待分类对象样本的至少一个类别,中间特征为通过分类分支模型的中间层输出的特征。第二估计单元703可基于真实分类标签以及中间特征,通过分类模型中的校验分支模型,得到待分类对象样本的第二估计结果,这里,第二估计结果表示通过校验分支模型估计的待分类对象样本的至少一个类别。损失计算单元704可根据第一估计结果、第二估计结果以及真实分类标签计算损失。模型训练单元705可通过根据损失调整分类分支模型和校验分支模型的模型参数,对分类模型进行训练。
由于图4所示的分类模型的训练方法可由图7所示的分类模型的训练装置700来执行,并且样本获取单元701、第一估计单元702、第二估计单元703、损失计算单元704和模型训练单元705可分别执行与图4中的步骤401、步骤402、步骤403、步骤404和405对应的操作,因此,关于图7中的各单元所执行的操作中涉及的任何相关细节均可参见关于图4的相应描述,这里都不再赘述。
此外,需要说明的是,尽管以上在描述分类模型的训练装置700时将其划分为用于分别执行相应处理的单元,然而,本领域技术人员清楚的是,上述各单元执行的处理也可以在分类模型的训练装置700不进行任何具体单元划分或者各单元之间并无明确划界的情况下执行。此外,分类模型的训练装置700还可包括其他单元,例如,数据处理单元、存储单元等。
图8是根据本公开的示例性实施例的电子设备800的框图。
参照图8,电子设备800包括至少一个存储器801和至少一个处理器802,所述至少一个存储器801中存储有计算机可执行指令集合,当计算机可执行指令集合被至少一个处理器802执行时,执行根据本公开的示例性实施例的分类方法或分类模型的训练方法。
作为示例,电子设备800可以是PC计算机、平板装置、个人数字助理、智能手机、或其他能够执行上述指令集合的装置。这里,电子设备800并非必须是单个的电子设备,还可以是任何能够单独或联合执行上述指令(或指令集)的装置或电路的集合体。电子设备800还可以是集成控制系统或系统管理器的一部分,或者可被配置为与本地或远程(例如,经由无线传输)以接口互联的便携式电子设备。
在电子设备800中,处理器802可包括中央处理器(CPU)、图形处理器(GPU)、可编程逻辑装置、专用处理器系统、微控制器或微处理器。作为示例而非限制,处理器还可包括模拟处理器、数字处理器、微处理器、多核处理器、处理器阵列、网络处理器等。
处理器802可运行存储在存储器801中的指令或代码,其中,存储器801还可以存储数据。指令和数据还可经由网络接口装置而通过网络被发送和接收,其中,网络接口装置可采用任何已知的传输协议。
存储器801可与处理器802集成为一体,例如,将RAM或闪存布置在集成电路微处理器等之内。此外,存储器801可包括独立的装置,诸如,外部盘驱动、存储阵列或任何数据库系统可使用的其他存储装置。存储器801和处理器802可在操作上进行耦合,或者可例如通过I/O端口、网络连接等互相通信,使得处理器802能够读取存储在存储器中的文件。
此外,电子设备800还可包括视频显示器(诸如,液晶显示器)和用户交互接口(诸如,键盘、鼠标、触摸输入装置等)。电子设备800的所有组件可经由总线和/或网络而彼此连接。
根据本公开的示例性实施例,还可提供一种存储指令的计算机可读存储介质,其中,当指令被至少一个处理器运行时,促使至少一个处理器执行根据本公开的分类方法或分类模型的训练方法。这里的计算机可读存储介质的示例包括:只读存储器(ROM)、随机存取可编程只读存储器(PROM)、电可擦除可编程只读存储器(EEPROM)、随机存取存储器(RAM)、动态随机存取存储器(DRAM)、静态随机存取存储器(SRAM)、闪存、非易失性存储器、CD-ROM、CD-R、CD+R、CD-RW、CD+RW、DVD-ROM、DVD-R、DVD+R、DVD-RW、DVD+RW、DVD-RAM、BD-ROM、BD-R、BD-R LTH、BD-RE、蓝光或光盘存储器、硬盘驱动器(HDD)、固态硬盘(SSD)、卡式存储器(诸如,多媒体卡、安全数字(SD)卡或极速数字(XD)卡)、磁带、软盘、磁光数据存储装置、光学数据存储装置、硬盘、固态盘以及任何其他装置,所述任何其他装置被配置为以非暂时性方式存储计算机程序以及任何相关联的数据、数据文件和数据结构并将所述计算机程序以及任何相关联的数据、数据文件和数据结构提供给处理器或计算机使得处理器或计算机能执行所述计算机程序。上述计算机可读存储介质中的计算机程序可在诸如客户端、主机、代理装置、服务器等计算机设备中部署的环境中运行,此外,在一个示例中,计算机程序以及任何相关联的数据、数据文件和数据结构分布在联网的计算机系统上,使得计算机程序以及任何相关联的数据、数据文件和数据结构通过一个或多个处理器或计算机以分布式方式存储、访问和执行。
根据本公开的示例性实施例,还可提供一种计算机程序产品,该计算机程序产品中的指令可由计算机设备的处理器执行以完成根据本公开的示例性实施例的分类方法或分类模型的训练方法。
根据本公开的分类及分类模型的训练方法、装置、设备及介质,在传统的分类模型基础上增加校验分支模型,从而实现对分类分支模型的分类结果进行二次校验,可进一步提高分类分支模型的准召率,提高分类的准确性。
本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本公开的其它实施方案。本申请旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本公开的真正范围和精神由下面的权利要求指出。
应当理解的是,本公开并不局限于上面已经描述并在附图中示出的精确结构,并且可以在不脱离其范围进行各种修改和改变。本公开的范围仅由所附的权利要求来限制。

Claims (10)

1.一种分类方法,其特征在于,包括:
获取待分类对象;
基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别和中间特征,所述第一类别为通过所述分类分支模型估计的所述待分类对象的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;
基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到所述待分类对象的第二类别,所述第二类别为通过所述校验分支模型对所述第一类别进行校验之后得到的至少一个类别;
将所述第二类别确定为所述待分类对象的类别。
2.如权利要求1所述的分类方法,其特征在于,所述校验分支模型包括类别特征获取模块、特征交互模块和分类模块,所述基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到所述待分类对象的第二类别,包括:
基于所述第一类别,利用所述类别特征获取模块,得到所述第一类别的类别特征;
基于所述类别特征以及所述中间特征,利用所述特征交互模块,得到配置权重的中间特征,所述权重表示所述类别特征与所述中间特征之间的相关程度;
基于所述配置权重的中间特征,利用所述分类模块,得到所述待分类对象的第二类别。
3.如权利要求2所述的分类方法,其特征在于,所述基于所述第一类别,利用所述类别特征获取模块,得到所述第一类别的类别特征,包括:
对所述第一类别进行单热编码,得到所述第一类别的类别列表;
根据所述类别列表从字典矩阵中查询得到所述类别特征,所述字典矩阵是维度空间中的向量表示,所述字典矩阵的行数与所述类别列表的列数相同,所述字典矩阵的列数与所述中间特征的维度相同。
4.如权利要求1所述的分类方法,其特征在于,所述基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别,包括:
基于所述待分类对象,利用所述分类分支模型,获取所述待分类对象的至少一个第一类别概率;
当所述至少一个第一类别概率大于或者等于所对应的第一阈值时,将所述至少一个第一类别概率对应的至少一个类别确定为所述待分类对象的第一类别。
5.一种分类模型的训练方法,其特征在于,包括:
获取待分类对象样本,所述待分类对象样本对应有真实分类标签;
基于所述待分类对象样本,通过所述分类模型中的分类分支模型,得到所述待分类对象样本的第一估计结果和中间特征,所述第一估计结果表示通过所述分类分支模型估计的所述待分类对象样本的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;
基于所述真实分类标签以及所述中间特征,通过所述分类模型中的校验分支模型,得到所述待分类对象样本的第二估计结果,所述第二估计结果表示通过所述校验分支模型估计的所述待分类对象样本的至少一个类别;
根据所述第一估计结果、所述第二估计结果以及所述真实分类标签计算损失;
通过根据所述损失调整所述分类分支模型和校验分支模型的模型参数,对所述分类模型进行训练。
6.一种分类装置,其特征在于,包括:
对象获取单元,被配置为:获取待分类对象;
第一类别获取单元,被配置为:基于所述待分类对象,利用分类模型中的分类分支模型,得到所述待分类对象的第一类别和中间特征,所述第一类别为通过所述分类分支模型估计的所述待分类对象的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;
第二类别获取单元,被配置为:基于所述第一类别以及所述中间特征,利用所述分类模型中的校验分支模型,得到所述待分类对象的第二类别,所述第二类别为通过所述校验分支模型对所述第一类别进行校验之后得到的至少一个类别;
类别确定单元,被配置为:将所述第二类别确定为所述待分类对象的类别。
7.一种分类模型的训练装置,其特征在于,包括:
样本获取单元,被配置为:获取待分类对象样本,所述待分类对象样本对应有真实分类标签;
第一估计单元,被配置为:基于所述待分类对象样本,通过所述分类模型中的分类分支模型,得到所述待分类对象样本的第一估计结果和中间特征,所述第一估计结果表示通过所述分类分支模型估计的所述待分类对象样本的至少一个类别,所述中间特征为通过所述分类分支模型的中间层输出的特征;
第二估计单元,被配置为:基于所述真实分类标签以及所述中间特征,通过所述分类模型中的校验分支模型,得到所述待分类对象样本的第二估计结果,所述第二估计结果表示通过所述校验分支模型估计的所述待分类对象样本的至少一个类别;
损失计算单元,被配置为:根据所述第一估计结果、所述第二估计结果以及所述真实分类标签计算损失;
模型训练单元,被配置为:通过根据所述损失调整所述分类分支模型和校验分支模型的模型参数,对所述分类模型进行训练。
8.一种电子设备,其特征在于,包括:
至少一个处理器;
至少一个存储计算机可执行指令的存储器,
其中,所述计算机可执行指令在被所述至少一个处理器运行时,促使所述至少一个处理器执行如权利要求1到4中的任一权利要求所述的分类方法或如权利要求5所述的分类模型的训练方法。
9.一种存储指令的计算机可读存储介质,其特征在于,当所述指令被至少一个处理器运行时,促使所述至少一个处理器执行如权利要求1到4中的任一权利要求所述的分类方法或如权利要求5所述的分类模型的训练方法。
10.一种计算机程序产品,包括计算机指令,其特征在于,所述计算机指令被至少一个处理器执行时实现如权利要求1到4中的任一权利要求所述的分类方法或如权利要求5所述的分类模型的训练方法。
CN202210021106.4A 2022-01-10 2022-01-10 分类及分类模型的训练方法、装置、设备及介质 Pending CN114359809A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210021106.4A CN114359809A (zh) 2022-01-10 2022-01-10 分类及分类模型的训练方法、装置、设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210021106.4A CN114359809A (zh) 2022-01-10 2022-01-10 分类及分类模型的训练方法、装置、设备及介质

Publications (1)

Publication Number Publication Date
CN114359809A true CN114359809A (zh) 2022-04-15

Family

ID=81107168

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210021106.4A Pending CN114359809A (zh) 2022-01-10 2022-01-10 分类及分类模型的训练方法、装置、设备及介质

Country Status (1)

Country Link
CN (1) CN114359809A (zh)

Similar Documents

Publication Publication Date Title
Shivakumara et al. CNN‐RNN based method for license plate recognition
US20230376527A1 (en) Generating congruous metadata for multimedia
CN111382555B (zh) 数据处理方法、介质、装置和计算设备
CN112633419A (zh) 小样本学习方法、装置、电子设备和存储介质
CN112789626A (zh) 可扩展和压缩的神经网络数据储存系统
CN115443490A (zh) 影像审核方法及装置、设备、存储介质
CN111783712A (zh) 一种视频处理方法、装置、设备及介质
CN116303459A (zh) 处理数据表的方法及系统
CN114003758B (zh) 图像检索模型的训练方法和装置以及检索方法和装置
CN113435499A (zh) 标签分类方法、装置、电子设备和存储介质
CN111382620A (zh) 视频标签添加方法、计算机存储介质和电子设备
CN117011737A (zh) 一种视频分类方法、装置、电子设备和存储介质
CN117251761A (zh) 数据对象分类方法、装置、存储介质及电子装置
CN113255824B (zh) 训练分类模型和数据分类的方法和装置
CN116958622A (zh) 数据的分类方法、装置、设备、介质及程序产品
CN116955707A (zh) 内容标签的确定方法、装置、设备、介质及程序产品
CN114359809A (zh) 分类及分类模型的训练方法、装置、设备及介质
CN115269998A (zh) 信息推荐方法、装置、电子设备及存储介质
CN115080856A (zh) 推荐方法及装置、推荐模型的训练方法及装置
Nag et al. CNN based approach for post disaster damage assessment
CN111611981A (zh) 信息识别方法和装置及信息识别神经网络训练方法和装置
CN117058432B (zh) 图像查重方法、装置、电子设备及可读存储介质
CN116578867A (zh) 标识生成方法及电子设备
CN117011539A (zh) 目标检测方法、目标检测模型的训练方法、装置及设备
CN114371937A (zh) 模型训练方法、多任务联合预测方法、装置、设备及介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination