CN117315685A - 分类模型训练方法、分类方法、装置及电子设备 - Google Patents

分类模型训练方法、分类方法、装置及电子设备 Download PDF

Info

Publication number
CN117315685A
CN117315685A CN202311279540.3A CN202311279540A CN117315685A CN 117315685 A CN117315685 A CN 117315685A CN 202311279540 A CN202311279540 A CN 202311279540A CN 117315685 A CN117315685 A CN 117315685A
Authority
CN
China
Prior art keywords
prompt
image
sample
text
dimension
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311279540.3A
Other languages
English (en)
Inventor
杜俊珑
鄢科
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN202311279540.3A priority Critical patent/CN117315685A/zh
Publication of CN117315685A publication Critical patent/CN117315685A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition
    • G06V30/19Recognition using electronic means
    • G06V30/191Design or setup of recognition systems or techniques; Extraction of features in feature space; Clustering techniques; Blind source separation
    • G06V30/19147Obtaining sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition
    • G06V30/18Extraction of features or characteristics of the image
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition
    • G06V30/19Recognition using electronic means
    • G06V30/19007Matching; Proximity measures
    • G06V30/19093Proximity measures, i.e. similarity or distance measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition
    • G06V30/19Recognition using electronic means
    • G06V30/191Design or setup of recognition systems or techniques; Extraction of features in feature space; Clustering techniques; Blind source separation
    • G06V30/19173Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition
    • G06V30/19Recognition using electronic means
    • G06V30/191Design or setup of recognition systems or techniques; Extraction of features in feature space; Clustering techniques; Blind source separation
    • G06V30/1918Fusion techniques, i.e. combining data from various sources, e.g. sensor fusion

Landscapes

  • Engineering & Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Multimedia (AREA)
  • Theoretical Computer Science (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请提供了一种分类模型训练方法、分类方法、装置及电子设备,可应用于云技术、人工智能、智慧交通、辅助驾驶等各种场景,方法包括:利用分类模型中的图像特征提取网络提取样本图像特征;利用分类模型中的提示生成网络基于样本图像特征,生成与多个提示文本分别对应的样本提示特征;利用分类模型中的文本特征提取网络基于参考提示特征和样本提示特征,生成各提示文本在每个维度下的样本融合特征;基于各提示文本在每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失;基于第一模型损失调整分类模型的模型参数以及提示文本。通过采用上述方法,可以有效提升训练后的分类模型的分类效果。

Description

分类模型训练方法、分类方法、装置及电子设备
技术领域
本申请涉及人工智能技术领域,更具体地,涉及一种分类模型训练方法、分类方法、装置及电子设备。
背景技术
目前,在内容审核(如,图像审核)应用场景下,所涉及的分类任务多、标签数多且业务场景复杂,对于运营时效性要求高。若直接使用单模型进行分类,则分类效果难以应对复杂多变的应用场景,故而存在分类结果不够准确的问题。
发明内容
有鉴于此,本申请实施例提出了一种分类模型训练方法、分类方法、装置及电子设备,可以有效提升训练后的分类模型的准确性。
第一方面,本申请实施例提供了一种分类模型训练方法,该方法包括:获取训练数据,所述训练数据包括多张样本图像,所述样本图像在多个维度中至少一个维度下具有样本类别标签;利用分类模型中的图像特征提取网络对所述样本图像进行特征提取,得到样本图像特征;利用分类模型中的提示生成网络基于所述样本图像特征,生成与多个提示文本分别对应的样本提示特征;利用分类模型中的文本特征提取网络基于所述多个提示文本分别对应的参考提示特征和所述多个提示文本分别对应的样本提示特征,生成所述多个提示文本在所述多个维度中每个维度下的样本融合特征;基于所述多个提示文本在所述多个维度中每个维度下的所述样本融合特征、所述样本图像特征以及所述样本图像在至少一个维度下的样本类别标签,确定第一模型损失;基于所述第一模型损失调整所述分类模型的模型参数以及所述多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
第二方面,本申请实施例提供了一种图像分类方法,该方法包括:获取待处理图像;利用分类模型中的图像特征提取网络对待处理图像进行特征提取,得到目标图像特征;利用分类模型中的提示生成网络基于所述目标图像特征生成与多个提示文本分别对应的目标提示特征,所述多个提示文本用于从多个角度描述所述待处理图像中的对象;利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的目标提示特征,生成所述多个提示文本在所述多个维度中每个维度下的目标融合特征;基于目标图像特征和所述多个提示文本在所述多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别。
第三方面,本申请实施例提供了一种分类模型训练装置,该装置包括数据获取模块、第一图像特征提取模块、第一提示特征生成模块、第一融合特征生成模块、损失确定模块以及模型训练模块。数据获取模块用于获取训练数据,所述训练数据包括多张样本图像,所述样本图像在多个维度中至少一个维度下具有样本类别标签;第一图像特征提取模块,用于利用分类模型中的图像特征提取网络对所述样本图像进行特征提取,得到样本图像特征;第一提示特征生成模块,用于利用分类模型中的提示生成网络基于所述样本图像特征,生成与多个提示文本分别对应的样本提示特征;第一融合特征生成模块,用于利用分类模型中的文本特征提取网络基于所述多个提示文本分别对应的参考提示特征和所述多个提示文本分别对应的样本提示特征,生成所述多个提示文本在所述多个维度中每个维度下的样本融合特征;损失确定模块,用于基于所述多个提示文本在所述多个维度中每个维度下的所述样本融合特征、所述样本图像特征以及所述样本图像在至少一个维度下的样本类别标签,确定第一模型损失;模型训练模块,用于基于所述第一模型损失调整所述分类模型的模型参数以及所述多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
在一种可实施方式中,第一融合特征生成模块包括融合子模块以及特征提取子模块,融合子模块用于将每个所述提示文本的参考提示特征与所述提示文本对应的样本提示特征进行融合,得到各所述提示文本对应的样本融合提示特征;特征提取子模块,用于利用分类模型的文本特征提取网络对各所述提示文本对应的样本融合提示特征进行特征提取,得到各提示文本在所述多个维度中每个维度下的样本融合特征。
在一种可实施方式中,损失确定模块,还用于基于相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度,确定特征互斥损失;模型训练模块,还用于基于所述特征互斥损失调整所述提示文本。
在一种可实施方式中,损失确定模块还用于计算相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度;对所述相似度进行求和,得到特征互斥损失。
在一种可实施方式中,损失确定模块包括相似度计算子模块、类别确定子模块以及损失确定子模块。相似度计算子模块用于计算各所述提示文本在每个维度下的样本融合特征与所述样本图像特征之间的特征相似度;类别确定子模块,用于基于同一维度下各提示文本对应的样本融合特征与所述样本图像特征之间的特征相似度,确定样本图像在每个维度下的预测类别;损失确定子模块,用于基于所述样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别,确定第一模型损失。
在一种可实施方式中,损失计算子模块,还用于基于样本图像在至少一个维度下的类别标签以及样本图像在每个维度下的预测类别进行交叉熵损失计算,得到第一模型损失。
在一种可实施方式下,模型训练模块还用于基于所述第一模型损失调整所述提示生成网络的模型参数以及提示文本。
在一种可实施方式中,所述装置还包括:图文对获取模块、对比学习处理模块以及网络更新模块。图文对获取模块,用于获取多个图像文本对,所述图像文本对包括图像样本和文本样本;对比学习处理模块,用于将所述多个文本样本输入至文本特征提取网络,并将所述多个图像样本输入至图像特征提取网络,进行对比学习处理,得到属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像文本对的图像样本和文本样本之间的相似度;网络更新模块,用于基于属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像样本对的图像样本和文本样本之间的相似度获得第二模型损失;基于所述第二模型损失更新所述文本特征提取网络和所述图像特征提取网络,直至达到预训练结束条件。
第四方面,本申请实施例提供了一种图像分类装置,该装置包括图像获取模块、第二图像特征提取模块、第二提示特征生成模块、第二融合特征生成模块以及类别确定模块。图像获取模块,用于获取待处理图像;第二图像特征提取模块,用于利用分类模型中的图像特征提取网络对待处理图像进行特征提取,得到目标图像特征;第二提示特征生成模块,用于利用分类模型中的提示生成网络基于所述目标图像特征生成与多个提示文本分别对应的目标提示特征,所述多个提示文本用于从多个角度描述所述待处理图像中的对象;第二融合特征生成模块,用于利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的目标提示特征,生成所述多个提示文本在所述多个维度中每个维度下的目标融合特征;类别确定模块,用于类别确定模块,用于基于目标图像特征和所述多个提示文本在所述多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别。
在一种可实施方式中,类别确定模块还用于针对每个维度,基于所述目标图像特征和各提示文本在该维度下的融合特征,确定该维度下各提示文本的类别预测结果;对该维度下的各提示文本的类别预测结果进行加权求和,得到待处理图像在该维度下的类别。。
在一种可实施方式中,类别确定模块还用于计算每个目标融合特征与目标图像特征之间的相似度;针对每个维度下的每个提示文本对应的目标融合特征与目标图像特征之间的相似度,确定每个维度下与每个提示文本对应的类别预测结果。
第五方面,本申请实施例提供了一种电子设备,包括处理器以及存储器;一个或多个程序被存储在所述存储器中并被配置为由所述处理器执行以实现上述的方法。
第六方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有程序代码,其中,在所述程序代码被处理器运行时执行上述的方法。
第七方面,本申请实施例提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质获取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述的方法。
本申请实施例提供的一种分类模型训练方法、分类方法、装置及电子设备。方法包括:利用分类模型中的图像特征提取网络对样本图像进行特征提取,得到样本图像特征;利用分类模型中的提示生成网络基于样本图像特征,生成与多个提示文本分别对应的样本提示特征;利用分类模型中的文本特征提取网络基于每个提示文本的参考提示特征和各提示文本对应的样本提示特征,生成各提示文本在每个维度下的样本融合特征;基于各提示文本在每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失;基于所述第一模型损失调整所述分类模型的模型参数以及所述多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。采用上述方法,在分类模型训练中,通过加入可学习的提示文本,并在模型训练阶段,通过利用提示生成网络和文本特征提取网络基于样本图像特征学习到同一维度下与多个提示文本分别对应的具有差异性的样本融合特征,使得同一维度对应的多个样本融合特征可以从不同角度去描述同一对象,实现了对样本图像利用多个样本融合特征进行更准确且更全面描述,进而在利用样本图像特征、多个具有差异性的样本融合特征以及样本图像在至少一个维度下的样本类别标签进行损失计算并调整分类模型时可以有效提升训练后的模型的准确度,以及使分类模型能够用于对图像进行多维度分类。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了本申请实施例提供的一种分类模型训练方法的应用场景图;
图2示出了本申请实施例提出的一种分类模型训练方法的流程示意图;
图3示出了本申请实施例提出的一种分类模型的提示生成网络的示意图;
图4示出了本申请实施例提出的一种分类模型的模型结构示意图;
图5示出了本申请实施例提出的一种训练样本的获取流程框图;
图6示出了本申请实施例提出的一种分类模型训练方法的另一流程框图;
图7示出了本申请实施例提出的一种图像分类方法的流程示意图;
图8示出了本申请实施例提出的一种分类模型的分类示意图;
图9示出了本申请实施例剔除的一种分类模型的应用场景图;
图10示出了本申请实施例提出的一种分类模型训练装置的连接框图;
图11示出了本申请实施例提出的一种图像分类装置的连接框图;
图12示出了用于执行本申请实施例的方法的电子设备的结构框图。
具体实施方式
现在将参考附图更全面地描述示例实施方式。然而,示例实施方式能够以多种形式实施,且不应被理解为限于在此阐述的参考示例;相反,提供这些实施方式使得本申请将更加全面和完整,并将示例实施方式的构思全面地传达给本领域的技术人员。
此外,所描述的特征、结构或特性可以以任何合适的方式结合在一个或更多实施例中。在下面的描述中,提供许多具体细节从而给出对本申请的实施例的充分理解。然而,本领域技术人员将意识到,可以实践本申请的技术方案而没有特定细节中的一个或更多,或者可以采用其它的方法、组元、装置、步骤等。在其它情况下,不详细示出或描述公知方法、装置、实现或者操作以避免模糊本申请的各方面。
附图中所示的方框图仅仅是功能实体,不一定必须与物理上独立的实体相对应。即,可以采用软件形式来实现这些功能实体,或在一个或多个硬件模块或集成电路中实现这些功能实体,或在不同网络和/或处理器装置和/或微控制器装置中实现这些功能实体。
附图中所示的流程图仅是示例性说明,不是必须包括所有的内容和操作/步骤,也不是必须按所描述的顺序执行。例如,有的操作/步骤还可以分解,而有的操作/步骤可以合并或部分合并,因此实际执行的顺序有可能根据实际情况改变。
需要说明的是:在本文中提及的“多个”是指两个或两个以上。“和/或”描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。
随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,并发挥越来越重要的价值。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式作出反应的智能机器。以人工智能应用在机器学习上为例进行说明:
其中,机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。本申请的方案主要是利用机器学习进行图像分类。
图1是根据本申请一实施例示出的应用场景的示意图,如图1所示,该应用场景包括终端设备10和通过网络与终端设备10通信连接的服务器20。
终端设备10,终端设备10具体可以是手机、电脑、智能语音交互设备、智能家电、车载终端等,终端设备10可以设有用于展示数据的客户端。网络可以是广域网或者局域网,或者是二者的组合。
服务器20可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN、以及大数据和人工智能平台等基础云计算服务的云服务器。
若利用如图1中的终端设备10和服务器20进行分类模型训练,终端设备10可以向服务器20上传训练数据,服务器20在获取到训练数据后,利用分类模型中的图像特征提取网络对样本图像进行特征提取,得到样本图像特征;利用分类模型中的提示生成网络基于样本图像特征,生成与多个提示文本分别对应的样本提示特征;利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的样本提示特征,生成多个提示文本在多个维度中每个维度下的样本融合特征;基于多个提示文本在多个维度中每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失;基于第一模型损失调整分类模型的模型参数以及多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
通过采用上述方法,在分类模型训练中,通过加入可学习的提示文本,并在模型训练阶段,通过利用提示生成网络和文本特征提取网络基于样本图像特征学习到同一维度下与多个提示文本分别对应的具有差异性的样本融合特征,使得多个样本融合特征可以从不同角度去描述同一对象,实现了通过多个样本融合特征对样本图像进行更准确且更全面描述,进而在利用样本图像特征、多个具有差异性的样本融合特征以及样本图像在至少一个维度下的样本类别标签进行损失计算并调整分类模型时可以有效提升训练后的模型的准确度,以及使分类模型能够用于对图像进行多维度分类。
在分类模型训练完成后,可以在目标服务器上部署训练后的分类模型,其中,目标服务器可以是即时通讯服务器、视频播放服务器或者内容交互服务器等。目标服务器可以利用分类模型对其上存储的图像数据进行分类处理,并在分类处理结果为异常时,执行告警及删除等操作,还可以接收终端设备发送的待处理图像,以及利用训练后的分类模型对待处理图像进行分类,得到待处理图像在每个维度下的类别并向终端设备发送包括每个维度下的类别的分类结果。
下面将结合附图具体描述本申请的各实施例。
请参阅图2,图2所示为本申请还提供一种分类模型训练方法,可以应用于电子设备,该电子设备可以是上述的终端设备10或服务器20,该方法包括:
步骤S110:获取训练数据。
训练数据包括多张样本图像,样本图像在多个维度中至少一个维度下具有样本类别标签。
获取训练数据的方式可以是,从网站上爬取多张图像并进行分类标注后作为训练数据,也可以是从电子设备或与该电子设备关联的其他设备中获取预先存储的多张具有样本类别标签的图像作为训练数据,根据实际需求进行设置即可。
多张样本图像的尺寸大小可以相同也可以不同,若多张样本图像的尺寸大小不同,则可以将多张样本图像分别进行缩放处理以得到尺寸大小相同的多张样本图像。
其中,一个维度可以理解为图像分类的一种标准。维度可以与分类任务相关,具体地,多个维度可以是同一分类任务下图像分类的多种标准,即分类模型用于分类任务,该多个维度可根据具体的分类任务设定。例如,若分类任务是识别图像是否合格,图像是否合格可以从是否存在二维码、是否涉及存在文字广告、是否涉及不良场景(例如抽烟等)来确定,例如,若设定若图像中存在二维码、存在水印、存在文字广告、存在抽烟场景,则确定该图像为不合格的图像,那么,在此种情况下,是否存在二维码、是否涉及存在文字广告、是否涉及抽烟场景分别视为分类任务中的一个维度。
示例性的,若分类模型是用于对图像中的对象(如,动植物)进行分类,如用于对猫、狗、猪等动物进行分类,则样本图像的至少一个维度下的分类标签可以是:样本图像为猫这一维度下的分类标签,样本图像为狗这一维度下的分类标签,以及样本图像为猪这一维度下的分类标签等中的一种或多种,其中,上述一个维度下的分类标签为该维度下的二分类的分类标签(如,是或否的标签)中的一个。若分类模型是用于对图像是否为异常图像,以及在图像为异常图像时,具体获得异常类型,如第一类异常、第二类异常以及第三类异常等时,则样本图像为每个维度下的类别标签可以包括样本图像为正常图像这一维度下的分类标签以及样本图像为异常图像这一维度下的分类标签中的至少一种,其中,样本图像为正常图像这一维度下的分类标为二分类的分类标签(如,是或否的分类标签),样本图像为异常图像的这一维度下的分类标签为多分类的分类标签(如,为第一类异常、第二类异常或者第三类异常等的分类标签)。
步骤S120:利用分类模型中的图像特征提取网络对样本图像进行特征提取,得到样本图像特征。
其中,上述的分类模型中的图像特征提取网络可以是神经网络。具体的,上述的神经网络可以是ResNet残差网络、DenseNet经典网络、VGG卷积神经网络、AlexNet深度卷积神经网络、Swin-Transformer网络、MaxViT网络或者LeNet卷积神经网络等等任意可以进行图像特征提取的神经网络。
在本申请的一种可实施方式中,上述的分类模型中的图像特征提取网络可以是预训练的CLIP模型(Contrastive Language-Image Pre-Training,对比语言-图像预训练模型)中的图像特征提取网络。
其中,CLIP模型主要包含Text Encoder网络(文本特征提取网络)和ImageEncoder网络(图像特征提取网络),分别提取文本特征和图像特征,然后对提取的文本特征和图像特征进行比对学习让模型学习到文本-图像的匹配关系。在CLIP模型的预训练阶段,可以使用大规模训练样本进行训练,从而基于海量的训练样本使预训练的CLIP模型可以学习到更多通用的视觉语义信息,给下游任务(如,图文检索、文本视频检索、图文问答、图文生成等图像和/或文本处理任务)提高帮助。在本实施例中,主要将预训练CLIP模型的图像特征提取网络用作分类模型中的图像特征提取网络,还可以将预训练CLIP模型的文本特征提取网络用作分类模型中的文本特征提取网络。
步骤S130:利用分类模型中的提示生成网络基于样本图像特征,生成与多个提示文本分别对应的样本提示特征。
其中,提示文本是指用于辅助描述样本图像中的对象的提示信息,且提示文本为可训练的文本。示例性的,在初始情况下,多个提示文本可以包括“a photo of xxx”、“Thispicture is used to describe XX”以及“There are XX in this picture”等。之后模型的训练过程,可以不断调整提示文本,使得调整后的提示文本可以从多个不同描述角度描述样本图像中的对象,示例性的,以分类模型用于为动物进行分类为例,则多个提示信息可以包括动物物种的学术角度的提示信息,人们喜好的宠物角度的提示信息,以及描述动物的具体形态特征角度的提示信息等中的多种。
其中,上述的提示生成网络可以是任意能够学习输入和输出之间的映射关系的网络,如可以是MLP(Multilayer perceptron,多层感知机)网络,也可以是全连接网络或者Transformer网络等。
在本申请的一种可实施方式中,提示生成网络可以是MLP网络,上述的MLP网络的层数可以为多层,如至少3层。且该MLP网络用于学习输入的样本图像特征与多个样本提示特征之间的关系。
如图3所示,示出了提示生成网络为MLP网络示意图,MLP网络包括输入层、隐藏层以及输出层,在利用该MLP网络基于样本图像特征生成与多个提示文本分别对应的样本提示特征时,输入层用于接收样本图像特征。在后续的数学表述中,可以将输入的样本图像特征记为X0,后续的隐藏层从1开始计数,第i个隐藏层对应的输入为Xi-1。隐藏层是MLP网络的核心,每一个隐藏层可以将其拆分为两部分:全连接和激活函数。全连接层可以视为用一个权重矩阵Wi与输入Xi-1相乘,再添加一个偏置项bi,这个偏置项可以直接放入权重矩阵Wi中进行训练。激活函数Activation Function:激活函数起非线性映射的作用,其可将神经元的输出幅度限制在一定范围内,一般限制在(-1~1)或(0~1)之间。输出层用于依据激活函数的计算来输出最后的结果。在本申请实施例中,输出层输出的结果为多个(如,a1,a2,...,aN),且每个输出结果与一个提示文本对应。即,在本申请中,提示生成网络输出的样本提示特征的数量与初始给定的提示文本的数量是相同的。
在本申请中,各提示文本中并没有明确描述的对象,例如上文中“a photo ofxxx”,该提示文本中的“xxx”是以掩码的形式表示的,通过提示生成网络来基于样本图像特征来生成与各提示文本相适配的样本提示特征,与一提示文本相适配的样本提示特征可以与提示文本对应的特征(例如下文中提示文本对应的参考提示特征)相组合,作为对样本图像中对象的对象描述特征。与一提示文本相适配的样本提示特征可以是提示生成网络基于样本图像特征所学习到样本图像中对象的特征。
步骤S140:利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的样本提示特征,生成多个提示文本在多个维度中每个维度下的样本融合特征。
在一些实施例中,提示文本的参考提示特征可以是对提示文本中全部词的词向量进行拼接得到,其中,提示文本中各词的词向量可以从词典中获得。在一些实施例中,提示文本的参考提示特征可以是对提示文本进行语义编码得到的特征,其中,语义编码阶段是自然语言处理中的一个重要环节,其主要任务是将自然语言文本转换为计算机能够理解或处理的形式。在对提示文本进行语义编码时,可以根据从语料库中包括词语与对应的语义编码结果中获取与提示文本中的各词语对应的语义特征。其中,上述的文本特征提取网络可以是神经网络。具体的,上述的神经网络可以是word2vec网络、卷积神经网络、循环神经网络、或者Transformer网络等等任意可以进行文本特征提取的神经网络。
在本申请的一种可实施方式中,上述的分类模型中的图像特征提取网络可以是预训练的CLIP模型(Contrast ive Language-Image Pre-Train ing,对比语言-图像预训练模型)中的文本特征提取网络。
其中,利用分类模型的文本特征提取网络生成多个维度中每个维度下的样本融合特征时,可以先对每个提示文本对应的样本提示特征和该提示文本对应的参考提示特征进行融合,得到该提示文本对应的样本融合提示特征,之后利用分类模型的文本特征提取网络基于样本融合提示特征生成各提示文本在多个维度下的样本融合特征;还可以直接利用文本特征提取网络基于每个提示文本对应的样本提示特征和参考提示特征生成每个提示文本在多个维度下的样本融合特征。
在本申请的一种可实施方式中,上述步骤S140具体可以是,将每个提示文本的参考提示特征与提示文本对应的样本提示特征进行融合,得到各提示文本对应的样本融合提示特征。利用分类模型的文本特征提取网络对各提示文本对应的样本融合提示特征进行特征提取,得到各提示文本在多个维度中每个维度下的样本融合特征。
其中,将每个提示文本的参考提示特征与提示文本对应的样本提示特征进行融合的方式可以是,针对每个提示文本,将该提示文本的参考提示特征与该提示文本对应的样本提示特征进行对位相乘、对位相加或者拼接等方式,得到各提示文本对应的样本融合提示特征。
其中,上述的多个维度可以是在分类模型进行模型训练之前确定,也可以是在分类模型的模型训练阶段基于输入的训练样本中样本标签的维度确定。
步骤S150:基于多个提示文本在多个维度中每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失。
在一种可实施方式中,上述步骤S150具体可以是基于多个提示文本在多个维度中每个维度下的样本融合特征和样本图像特征确定样本图像在每个维度下的预测类别,基于样本图像在每个维度下的预测类别和样本图像在至少一个维度下的类别标签进行损失计算,得到第一模型损失。
在该种可实施方式下,可以基于各提示文本在每个维度下的样本融合特征和样本图像特征确定每个维度下的预测类别的方式可以是,计算各样本融合特征与样本图像特征之间的相似度,基于同一维度下各提示文本对应的样本融合特征相似度,确定样本图像在每个维度下的预测类别;基于样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别,确定第一模型损失。
其中,计算样本融合特征与样本图像特征之间的相似度的方式可以是,计算样本融合特征与样本图像之间的余弦相似度或欧式距离等。在确定第一模型损失时,可以利用预设损失函数基于样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别进行损失计算,得到第一模型损失。其中,预设损失函数可以是交叉熵损失函数、均方误差损失函数或者多分类交叉熵损失函数等,根据实际需求进行设置即可。
在另一种可实施方式中,上述步骤S150还可以是,对样本融合特征、样本融合特征的转置、样本图像特征以及样本图像特征的转置分别进行正则化处理,计算每个维度下的正则化处理后的各样本融合特征与正则化处理后的样本图像特征的转置之间的第一相似度,以及计算每个维度下的正则化处理后的各样本融合特征的转置与样本图像特征之间的第二相似度,基于每个维度下样本融合特征对应的第一相似度和第二相似度以及样本图像在至少一个维度下的样本类别标签计算模型损失。
示例性的,在进行损失计算时,可以利用如下损失计算公式进行计算,损失计算公式如下:
其中,Ienc为样本图像特征,Tenc为样本融合特征,norm为L2正则化,其中,L2正则化的目的是减少分类模型训练过程中过拟合问题,t为温度系数,e为自然常数,norm(Ienc).norm(Tenc)T表征第一相似度,(norm(Ienc)T.norm(Tenc)表征第二相似度,CrossEntropyLoss为交叉熵损失函数。需要说明的是,以上损失计算公式中,L2正则化的结果是通过样本图像在至少一个维度下的样本类别标签以及如上的样本图像特征和样本融合特征确定的,其可以仅用于计算基于样本图像在某一维度下的一个样本融合特征对应的分类预测结果与样本图像在至少一个维度下的样本标签之间的损失,从而基于每个维度下的每个样本融合特征对应的损失得到模型损失。
应当理解,上述确定第一模型损失的方式仅为示意性的,还可以有更多的确定方式,在本实施例不再一一赘述。
请参阅图4,若利用图像特征提取网络对样本图像提取到的样本图像特征为(I1,I2,I3,...,In),其中,n表示特征维度,且多个提示文本对应的参考提示特征为Prompt 1、Prompt 2以及Prompt 3,则提示生成网络生成的与多个提示文本分别的对应的样本提示特征为a1、a2以及a3,利用文本特征提取网络获得的各提示文本在多个维度下的样本融合特征为(T1,T2,T3,...,Tn),其中,Tn表示该提示文本在第n个维度下的样本融合特征,通过计算基于同一维度下各提示文本对应的样本融合特征相似度其中,IiTj标识第i维度的样本图像特征与第j维度的样本文本特征之间的相似度,,确定样本图像在每个维度下的预测类别,以及根据样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别,确定第一模型损失。
步骤S160:基于第一模型损失调整分类模型的模型参数以及多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
应当理解,当训练次数达到第一预设次数或者模型损失小于第一预设损失阈值时,则认为对分类模型的训练达到训练结束条件,且上述达到训练结束条件的分类模型即可用于执行后续的图像分类。其中,第一预设次数和第一预设损失阈值可以根据任务需求进行设置,此处不作具体限定。
其中,上述调整分类模型的模型参数可以是调整分类模型中的图像特征提取网络、提示生成网络以及文本特征提取网络中的至少一个网络的模型参数。
在本申请的一种可实施方式中,若在步骤S120之前对分类模型中的图像特征提取网络和文本特征提取网络进行了预训练,在上述步骤S160可以是,基于第一模型损失调整分类模型中提示生成网络的模型参数以及提示文本,即预训练后,在如上的训练过程中,固定图像特征提取网络和文本特征提取网络的参数,调整提示生成网络的模型参数和调整提示文本,以使提升生成网络可以自动学习和挖掘与当前图像和任务适配的多个样本提示特征,提升所生成的样本提示特征的适配性和多样性,而且,由于样本提示特征最后是与对应提示文本的参考提示特征进行融合,这样,可以提升融合效果。
在一些实施例中,在步骤S120之前,可按照如下的步骤S170~步骤S190所示的过程对图像特征提取网络和文本特征提取网络进行预训练:
步骤S170:获取多个图像文本对。
图像文本对包括图像样本和文本样本。
其中,一图像文本对中的文本样本是对该图像文本对中的图像样本的文本描述。
其中,获取多个图像文本对的方式可以是获取数据库或电子设备中存储的图像文本对,也可以是获取互联网上存在的大量图像文本对,图像文本对还可以是利用文生图模型基于样本文本生成对应的样本图像,或者利用图生文模型基于图像文本生成对应的样本文本。
其中,若图像文本对还可以是利用文生图模型基于样本文本生成对应的样本图像,或者利用图生文模型基于图像文本生成对应的样本文本,则文生图模型或者图生文模型可以是基于半监督训练得到,也可以是基于无监督训练得到。
在半监督场景下,如图5所示,文生图模型或图生图模型可以是使用有标注数据训练初版模型,然后用初版模型对无标注数据进行预测打标签得到为伪标注数据,最后利用伪标注数据和有标注数据共同进行训练得到。
步骤S180:将多个文本样本输入至文本特征提取网络,并将多个图像样本输入至图像特征提取网络,进行对比学习处理,得到属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像文本对的图像样本和文本样本之间的相似度。
通过将文本样本输入到文本特征提取网络可以提取到文本样本的特征,将图像文本输入到图像特征提取网络可以提取到图像样本的特征,通过将图像文本的特征与文本样本的特征进行对比学习,可以获得多个文本样本中任意一个文本样本与多个图像样本中的每个图像样本之间的相似度,从而获知属于不同图像文本对的图像样本和文本样本之间的相似度,以及属于同一图像文本对的图像样本和文本样本之间的相似度。
步骤S190:基于属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像样本对的图像样本和文本样本之间的相似度获得第二模型损失,并基于第二模型损失更新文本特征提取网络和图像特征提取网络,以使属于同一图像样本对的图像样本和文本样本之间的相似度增大,以及使属于不同图像文本对的图像样本与文本样本之间的特征相似度减小,直至达到预训练结束条件。
在预训练过程中,基于图像文本对中的图像样本和文本样本来对图像特征提取网络和文本特征提取网络,使图像特征提取网络和文本特征提取网络学习到同一图像文本对中的图像样本和文本样本之间的匹配关系,并且,可以学习到来自不同图像文本对的图像样本和文本样本之间的差异。因此,在预训练过程中,属于同一图像文本对的图像样本和文本样本之间的相似度越高,则对第二损失的影响越小,反之,属于同一图像文本对的图像样本和文本样本之间的相似度越低,则对第二损失的影响越大。属于不同图像文本对的图像样本与文本样本之间的相似度越高,则对第二损失的影响越大,反之属于于不同图像文本对的图像样本与文本样本之间的相似度越低,则对第二损失的影响越小。
其中,当训练次数达到第二预设次数或者模型损失小于第二预设损失阈值时,则认为对分类模型的训练达到预训练结束条件,其中,第二预设次数和第二预设损失阈值可以根据实际需求进行设置,此处不作具体限定。
通过采用本申请的上述模型训练方法,在分类模型训练中,通过加入可学习的提示文本,并在模型训练阶段,通过利用提示生成网络和文本特征提取网络基于样本图像特征学习到同一维度下与多个提示文本分别对应的具有差异性的样本融合特征,使得同一维度对应的多个样本融合特征可以从不同角度去描述同一对象,实现了对样本图像利用多个样本融合特征进行更准确且更全面描述,进而在利用样本图像特征、多个具有差异性的样本融合特征以及样本图像在至少一个维度下的样本类别标签进行损失计算并调整分类模型时可以有效提升训练后的模型的准确度。
此外,针对在内容审核应用场景,如图像审核分类任务多,标签数多,业务场景复杂,但是运营时效性要求高。直接使用单模型进行分类,效果难以应对复杂多变的应用场景,同时难以达到快速优化迭代的运营要求。本申请实施例针对此种现象,提出一种基于提示学习的高效参数学习的多模型集成方法(即,上述的模型训练方法)使训练得到的模型不仅可以得到应对复杂多变业务场景,同时由于训练过程中仅调整提示生成网络的参数和调整提示文本,使得训练代价极小,同时能够实现同时执行多分类任务,使得该模型可以应该用在复杂的应用场景下以达到快速运营的效果。
进一步的,可以先对分类模型中的图像特征提取网络和文本特征提取网络进行预训练,这样,在将分类模型进行下游任务的训练过程中,可以不需要调整图像特征提取网络和文本特征提取网络的参数,仅调整分类模型中的提示生成网络及提示文本,使得本申请的分类模型中需要学习的参数数量较少,训练分类模型过程中所需的代价较小,从而提升了模型训练的效率。
请参阅图6所示,本申请实施例还提供一种分类模型训练方法,该方法包括:
步骤S210:获取训练数据。
其中,训练数据包括多张样本图像,样本图像在多个维度中至少一个维度下具有样本类别标签。
步骤S220:利用分类模型中的图像特征提取网络对样本图像进行特征提取,得到样本图像特征。
步骤S230:利用分类模型中的提示生成网络基于样本图像特征,生成与多个提示文本分别对应的样本提示特征。
步骤S240:利用分类模型的文本特征提取网络对各提示文本对应的样本融合提示特征进行特征提取,得到各提示文本在多个维度中每个维度下的样本融合特征。
步骤S250:基于多个提示文本在多个维度中每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失。
关于上述步骤S210-S250具体描述可以参阅前文对步骤S110-S150的具体描述,在本实施例不再一一赘述。
步骤S260:计算相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度。
其中,上述计算相似度的方式可以是计算相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的余弦相似度或者欧式距离等。
步骤S270:基于相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度,确定特征互斥损失。
其中,特征互斥损失可以是基于全部相似度的和确定,也可以是基于全部相似度的平均值确定,还可以是基于全部相似度的最大值确定。
在一种可实施方式中,上述步骤S270具体可以是,计算相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度;对相似度进行求和,得到特征互斥损失。
示例性的,特征互斥损失还可以利用如下公式计算得到:式中,A和B表示相同维度下的一组样本融合特征中的两个样本融合特征,n表示维度的数量。
在获得特征互斥损失之后,可以利用特征互斥损失调整提示文本,从而降低后续获得的样本融合特征之间的相似性,提升样本融合特征之间的正交性(即差异性),从而增强提示生成网络对一个样本图像中的对象从不同的角度进行描述的能力,比如对于“猫”这个维度,可以从动物物种的学术角度,也可以从生活大家喜好的宠物角度,或者描述“猫”的具体形态特征的角度去提示,这些正交的角度可以形成对“猫”这个类别的互补性描述,从而达到提升集成子网络结果之间的互补性的目的。
步骤S280:基于特征互斥损失调整提示文本,以及基于第一模型损失调整分类模型的模型参数以及多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
应当理解,提升互斥损失还可以用作调整分类模型的模型参数,也即上述步骤S280可以是,基于第一模型损失以及特征互斥损失调整分类模型的模型参数以及提示文本,直至达到训练结束条件。
通过采用上述方法,在分类模型训练中,通过加入可学习的提示文本,并在模型训练阶段,利用提示生成网络和文本特征提取网络基于样本图像特征学习到同一维度下与多个提示文本分别对应的具有差异性的样本融合特征,使得同一维度对应的多个样本融合特征可以从不同角度去描述同一对象,此外,在学习阶段无需显式地调整样本的权重而是通过学习到有差异性和互补的样本融合特征,实现了对样本图像利用多个样本融合特征进行更准确且更全面描述,进而在利用样本图像特征、多个具有差异性互补性的样本融合特征以及样本图像在至少一个维度下的样本类别标签进行损失计算并调整分类模型时可以有效提升训练后的模型的准确度。
请参阅图7所示,本申请实施例还提供了一种图像分类方法,该方法包括:
步骤S310:获取待处理图像。
其中,待处理图像可以是任意需要进行处理的图像,如网络中传输的图像,视频中的图像,或者图像中的局部图像等。
相应的,上述步骤S310可以是,获取网络中传输的图像,如即时通讯服务器中传输的图像或任意网站中的图像等;上述步骤S310还可以是,获取视频中的图像,在该种方式下,上述步骤S310具体可以是对目标视频进行抽帧得到待处理图像;上述步骤S310还可以,获取目标图像中的局部图像,目标图像可以是多张图像拼接而成的图像中的一张图像,在该种方式下,上述步骤S310具体可以是对目标图像进行切分得到待处理图像。
步骤S320:利用分类模型中的图像特征提取网络对待处理图像进行特征提取,得到目标图像特征。
其中,分类模型可以利用前述实施例中的分类模型训练方法训练得到。关于分类模型的具体训练过程可以参阅前述实施例中的具体描述,在本申请实施例不再一一赘述。
步骤S330:利用分类模型中的提示生成网络基于目标图像特征生成与多个训练后的提示文本分别对应的目标提示特征。
其中,多个提示文本用于对待处理图像中的对象采用多个描述角度进行描述的文本。
步骤S340:利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的目标提示特征,生成多个提示文本在多个维度中每个维度下的目标融合特征。
其中,上述步骤S320-S340的处理过程与前述实施例中的步骤S120-S140的处理过程类似,因此,关于步骤S320-S340的具体处理过程可以参阅前文对步骤S120-S140的具体描述,在本实施例不在一一赘述。
步骤S350:基于目标图像特征和多个提示文本在多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别。
上述步骤S350可以是,针对每个维度,基于目标图像特征和各提示文本在该维度下的融合特征,确定该维度下各提示文本的类别预测结果;对该维度下的各提示文本的类别预测结果进行加权求和,得到待处理图像在该维度下的类别。
具体的,针对每个维度,基于目标图像特征和各提示文本在该维度下的融合特征,确定该维度下各提示文本的类别预测结果时,可以计算每个目标融合特征与目标图像特征之间的相似度;针对每个维度下的每个提示文本对应的目标融合特征与目标图像特征之间的相似度,确定每个维度下与每个提示文本对应的类别预测结果。
其中,计算每个目标融合特征与目标图像特征之间的相似度方式可以是,计算每个目标融合特征与目标图像特征之间欧式距离或者余弦相似度等。其具体计算过程应当与前述实施例中计算样本融合特征与样本图像特征之间的相似度的计算过程相类似,因此,关于上述步骤S350的具体描述可以参阅前文对步骤S150中计算相似度的过程,此处不再一一赘述。
其中,上述步骤S350还可以是,基于同一维度对应的多个目标融合特征各自对应的相似度,确定每个维度分别对应的目标相似度;基于每个维度对应的目标相似度得到每个维度下的类别。
其中,目标相似度可以是同一维度对应的多个目标融合特征各自对应的相似度的均值或中值等。
示例性地,可以将每个目标融合特征对应的维度以及每个目标融合特征对应的相似度
的平均值作为目标相似度:
式中,Y表示待处理图像在每个维度分别对应的目标相似度,N表示提示文本的数量,norm表示正则化,Ienc为样本图像特征,Tenc为样本融合特征。
上述步骤S250还可以是针对每个维度,基于该维度对应的多个目标融合特征各自对应的相似度中的最大相似度或最小相似度,确定该维度下的类别。
采用本申请的图像分类方法,可以实现在分类过程中,利用对待处理图像中的对象采用多个描述角度进行描述的提示文本和图像特征获得目标融合特征,使得同一维度对应的多个目标融合特征可以从不同角度去描述待处理图像中的对象,以便后续在对待处理图像进行分类时,能够基于在利用待处理图像的目标图像特征、多个具有差异性的目标融合特征来实现对待处理图像进行分类预测,从而可以有效提升分类结果的准确性。
请结合参阅图8和图9所示,本申请实施例提供了一种分类模型训练方法,且以该训练方法的训练得到的分类模型用作对待处理图像进行是否异常的分类审核为例进行说明。
为了使模型能够分辨图像是正常还是违规(也即,异常,具体异常类型可能有很多种),需要准备大量的多分类标注数据,对于分类模型进行训练。其中,分类模型包括图像特征提取模块(即,前述的图像特征提取网络)、多提示学习模块(即,前述的提示生成网络)、文本特征提取模块(即,前述的文本特征提取网络)以及互斥模块(用于计算特征互斥损失的模块)。
图像特征提取模块用于提取图像特征,其具体为CLIP结构的图文预训练网络中Transformer结构的图像编码器。
多提示学习模块用于由一个简单三层MLP网络构成,用于生成与多个提示文本分别对应的样本提示特征,如a1、a2及a3
文本特征提取模块用于基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的样本提示特征,生成多个提示文本在多个维度中每个维度下的样本融合特征,具体的,可以针对每个提示文本,将该提示文本的参考提示特征与该提示文本对应的样本提示特征进行拼接后输入到文本特征提取模块中最终生成各提示文本在多个维度中每个维度下的样本融合特征。示例性的,若某一提示文本在一个维度下的样本融合特征为:Prompt1Final=Prompt1(V1,V 2,...Vm)+MLP(Ienc),其中,其中Prompt1为提示文本的参考提示特征,由M个词向量特征V1,V 2,...Vm组成,Ienc为样本提示特征。
文本特征互斥损失用于对对同一图像同一维度的不同提示文本对应的样本融合特征进行两两之间的相似度计算,从而基于计算得到的相似度得到特征互斥损失,通过使用互斥损失模块调整分类模型,可以降低同一图像同一维度的不同提示文本对应的样本融合特征之间的相似性,提升同一图像同一维度的不同提示文本对应的样本融合特征之间的正交性,相当于实现时对样本图像中的同一个对象从不同的角度进行描述,比如对于“猫”这个类别,可以从动物物种的学术角度,也可以从生活大家喜好的宠物角度,或者描述猫的具体形态特征的角度去提示,这些正交的角度可以形成对“猫”这个类别的互补性描述,从而达到提升文本特征提取网络的输出结果(相同维度下多个提示文本各自对应的样本融合特征)之间的互补性的目的,并最终达到提升分类模型的分类效果的目的。
在获得样本融合特征之后,可以基于多个提示文本在多个维度中每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失从而基于第一模型损失和特征互斥损失调整分类模型中多提示学习模块(提示生成网络)以及提示文本。
在利用前述训练得到的分类模型进行图像分类时,用户可以上传需要进行审核的视频、动图或者图像(如,长图,长图是指由多张图像拼接而成的图像),通过对视频或动图经过抽帧得到长图,将上述获得的长图或任意的长图切分处理后,得到待审核图像序列,该序列图像依次送入到训练后的分类模型中,并利用分类模型中的图像特征提取网络、提示生成网络、文本特征提取网络相互配合执行与前述模型训练相似的过程(具体参阅前述实施例中对待处理图像进行分类的过程),从而可以输出待处理图像的维度(正常或异常)以及在每个维度下的具体分类结果,其中,在正常这一维度下的分类结果为是或否,异常这一维度下的具体的分类结果可以包括第一异常类型、第二异常类型以及第三异常类型等中的一种或多种。
请参阅图10,本申请另一实施例提供了一种分类模型训练装置400,该分类模型训练装置400包括数据获取模块410、第一图像特征提取模块420、第一提示特征生成模块430、第一融合特征生成模块440、损失确定模块450以及模型训练模块460。数据获取模块410用于获取训练数据,训练数据包括多张样本图像,样本图像在多个维度中至少一个维度下具有样本类别标签;第一图像特征提取模块420,用于利用分类模型中的图像特征提取网络对样本图像进行特征提取,得到样本图像特征;第一提示特征生成模块430,用于利用分类模型中的提示生成网络基于样本图像特征,生成与多个提示文本分别对应的样本提示特征;第一融合特征生成模块440,用于利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的样本提示特征,生成多个提示文本在多个维度中每个维度下的样本融合特征;损失确定模块450,用于基于多个提示文本在多个维度中每个维度下的样本融合特征、样本图像特征以及样本图像在至少一个维度下的样本类别标签,确定第一模型损失;模型训练模块460,用于基于第一模型损失调整分类模型的模型参数以及多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
在一种可实施方式中,第一融合特征生成模块440包括融合子模块以及特征提取子模块,融合子模块用于将每个提示文本的参考提示特征与提示文本对应的样本提示特征进行融合,得到各提示文本对应的样本融合提示特征;特征提取子模块,用于利用分类模型的文本特征提取网络对各提示文本对应的样本融合提示特征进行特征提取,得到各提示文本在多个维度中每个维度下的样本融合特征。
在一种可实施方式中,损失确定模块450,还用于基于相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度,确定特征互斥损失;模型训练模块,还用于基于特征互斥损失调整提示文本。
在一种可实施方式中,损失确定模块450还用于计算相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度;对相似度进行求和,得到特征互斥损失。
在一种可实施方式中,损失确定模块450包括相似度计算子模块、类别确定子模块以及损失确定子模块。相似度计算子模块用于计算各提示文本在每个维度下的样本融合特征与样本图像特征之间的特征相似度;类别确定子模块,用于基于同一维度下各提示文本对应的样本融合特征与样本图像特征之间的特征相似度,确定样本图像在每个维度下的预测类别;损失确定子模块,用于基于样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别,确定第一模型损失。
在一种可实施方式中,损失计算子模块,还用于基于样本图像在至少一个维度下的类别标签以及样本图像在每个维度下的预测类别进行交叉熵损失计算,得到第一模型损失。
在一种可实施方式下,模型训练模块460还用于基于第一模型损失调整提示生成网络的模型参数以及提示文本。
在一种可实施方式中,分类模型训练装置400还包括:图文对获取模块、对比学习处理模块以及网络更新模块。图文对获取模块,用于获取多个图像文本对,图像文本对包括图像样本和文本样本;对比学习处理模块,用于将多个文本样本输入至文本特征提取网络,并将多个图像样本输入至图像特征提取网络,进行对比学习处理,得到属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像文本对的图像样本和文本样本之间的相似度;网络更新模块,用于基于属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像样本对的图像样本和文本样本之间的相似度获得第二模型损失,并基于第二模型损失更新文本特征提取网络和图像特征提取网络,以使属于同一图像样本对的图像样本和文本样本之间的相似度增大,以及使属于不同图像文本对的图像样本与文本样本之间的特征相似度减小,直至达到预训练结束条件。
请参阅图11所示,本申请实施例还提供一种图像分类装置500,该图像分类装置500包括图像获取模块510、第二图像特征提取模块520、第二提示特征生成模块530、第二融合特征生成模块540以及类别确定模块550。图像获取模块510,用于获取待处理图像;第二图像特征提取模块520,用于利用分类模型中的图像特征提取网络对待处理图像进行特征提取,得到目标图像特征;第二提示特征生成模块530,用于利用分类模型中的提示生成网络基于目标图像特征生成与多个提示文本分别对应的目标提示特征,多个提示文本用于从多个角度描述待处理图像中的对象;第二融合特征生成模块540,用于利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的目标提示特征,生成多个提示文本在多个维度中每个维度下的目标融合特征;类别确定模块550,用于基于目标图像特征和多个提示文本在多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别。
在一种可实施方式中,类别确定模块550还用于针对每个维度,基于目标图像特征和各提示文本在该维度下的融合特征,确定该维度下各提示文本的类别预测结果;对该维度下的各提示文本的类别预测结果进行加权求和,得到待处理图像在该维度下的类别。
在一种可实施方式中,类别确定模块560还用于计算每个目标融合特征与目标图像特征之间的相似度;针对每个维度下的每个提示文本对应的目标融合特征与目标图像特征之间的相似度,确定每个维度下与每个提示文本对应的类别预测结果。
在一种可实施方式中,图像获取模块510,还用于对目标视频进行抽帧,得到待处理图像;或者对目标图像进行切分,得到待处理图像。
上述装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。需要说明的是,本申请中装置实施例与前述方法实施例是相互对应的,装置实施例中具体的原理可以参见前述方法实施例中的内容,此处不再赘述。
下面将结合图12对本申请提供的一种电子设备进行说明。
请参阅图12,本申请实施例还提供的另一种包括可以执行前述方法的处理器102的电子设备100,该电子设备100可以为服务器或终端设备,终端设备可以是智能手机、平板电脑、计算机或者便携式计算机等设备。该电子设备100可用于执行本申请提供的分类模型训练方法、或者执行图像分类方法。
电子设备100还包括存储器104。其中,该存储器104中存储有可以执行前述实施例中内容的程序,而处理器102可以执行该存储器104中存储的程序。
其中,处理器102可以包括一个或者多个用于处理数据的核以及消息矩阵单元。处理器102利用各种接口和线路连接整个电子设备100内的各个部分,通过运行或执行存储在存储器104内的指令、程序、代码集或指令集,以及调用存储在存储器104内的数据,执行电子设备100的各种功能和处理数据。可选地,处理器102可以采用数字信号处理(DigitalSignal Processing,DSP)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)、可编程逻辑阵列(Programmable Logic Array,PLA)中的至少一种硬件形式来实现。处理器102可集成中央处理器(Central Processing Unit,CPU)、图像处理器(GraphicsProcessing Unit,GPU)和调制解调器等中的一种或几种的组合。其中,CPU主要处理操作系统、用户界面和应用程序等;GPU用于负责显示内容的渲染和绘制;调制解调器用于处理无线通信。可以理解的是,上述调制解调器也可以不集成到处理器102中,单独通过一块通信芯片进行实现。
存储器104可以包括随机存储器(Random Access Memory,RAM),也可以包括只读存储器(Read-Only Memory)。存储器104可用于存储指令、程序、代码、代码集或指令集。存储器104可包括存储程序区和存储数据区,其中,存储程序区可存储用于实现操作系统的指令、用于实现至少一个功能的指令、用于实现下述各个方法实施例的指令等。存储数据区还可以存储电子设备100在使用中所获取的数据(如,训练数据或待处理图像)等。
电子设备100还可以包括网络模块以及屏幕,网络模块用于接收以及发送电磁波,实现电磁波与电信号的相互转换,从而与通讯网络或者其他设备进行通讯,例如和音频播放设备进行通讯。网络模块可包括各种现有的用于执行这些功能的电路元件,例如,天线、射频收发器、数字信号处理器、加密/解密芯片、用户身份模块(SIM)卡、存储器等等。网络模块可与各种网络如互联网、企业内部网、无线网络进行通讯或者通过无线网络与其他设备进行通讯。上述的无线网络可包括蜂窝式电话网、无线局域网或者城域网。屏幕可以进行界面内容的显示以及进行数据交互。
在一些实施例中,电子设备100还可以包括有:外设接口106和至少一个外围设备。处理器102、存储器104和外设接口106之间可以通过总线或信号线相连。各个外围设备可以通过总线、信号线或电路板与外设接口连接。具体地,外围设备包括:射频组件108等。
外设接口106可被用于将I/O(Input/Output,输入/输出)相关的至少一个外围设备连接到处理器102和存储器104。在一些实施例中,处理器102、存储器104和外设接口106被集成在同一芯片或电路板上;在一些其他实施例中,处理器102、存储器104和外设接口106中的任意一个或两个可以在单独的芯片或电路板上实现,本申请实施例对此不加以限定。
射频组件108用于接收和发射RF(Radio Frequency,射频)信号,也称电磁信号。射频组件108通过电磁信号与通信网络以及其他通信设备进行通信。射频组件108将电信号转换为电磁信号进行发送,或者,将接收到的电磁信号转换为电信号。可选地,射频组件108包括:天线系统、RF收发器、一个或多个放大器、调谐器、振荡器、数字信号处理器、编解码芯片组、用户身份模块卡等等。射频组件108可以通过至少一种无线通信协议来与其它终端进行通信。该无线通信协议包括但不限于:万维网、城域网、内联网、各代移动通信网络(2G、3G、4G及5G)、无线局域网和/或WiFi(Wireless Fidelity,无线保真)网络。在一些实施例中,射频组件108还可以包括NFC(Near Field Communication,近距离无线通信)有关的电路,本申请对此不加以限定。
本申请实施例还提供一种计算机可读存储介质的结构框图。该计算机可读介质中存储有程序代码,程序代码可被处理器调用执行上述方法实施例中所描述的方法。
计算机可读存储介质可以是诸如闪存、EEPROM(电可擦除可编程只读存储器)、EPROM、硬盘或者ROM之类的电子存储器。可选地,计算机可读存储介质包括非易失性计算机可读介质(non-transitory computer-readable storage medium)。计算机可读存储介质具有执行上述方法中的任何方法步骤的程序代码的存储空间。这些程序代码可以从一个或者多个计算机程序产品中读出或者写入到这一个或者多个计算机程序产品中。程序代码可以例如以适当形式进行压缩。
本申请实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该电子设备执行上述各种可选实现方式中描述的方法。
最后应说明的是:以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不驱使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围。

Claims (16)

1.一种分类模型训练方法,其特征在于,所述方法包括:
获取训练数据,所述训练数据包括多张样本图像,所述样本图像在多个维度中至少一个维度下具有样本类别标签;
利用分类模型中的图像特征提取网络对所述样本图像进行特征提取,得到样本图像特征;
利用分类模型中的提示生成网络基于所述样本图像特征,生成与多个提示文本分别对应的样本提示特征;
利用分类模型中的文本特征提取网络基于所述多个提示文本分别对应的参考提示特征和所述多个提示文本分别对应的样本提示特征,生成所述多个提示文本在所述多个维度中每个维度下的样本融合特征;
基于所述多个提示文本在所述多个维度中每个维度下的所述样本融合特征、所述样本图像特征以及所述样本图像在至少一个维度下的样本类别标签,确定第一模型损失;
基于所述第一模型损失调整所述分类模型的模型参数以及所述多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
2.根据权利要求1所述的方法,其特征在于,所述利用分类模型的文本特征提取网络基于所述多个提示文本分别对应的参考提示特征和所述多个提示文本分别对应的样本提示特征,生成所述多个提示文本在所述多个维度中每个维度下的样本融合特征,包括:
将每个所述提示文本的参考提示特征与所述提示文本对应的样本提示特征进行融合,得到各所述提示文本对应的样本融合提示特征;
利用分类模型的文本特征提取网络对各所述提示文本对应的样本融合提示特征进行特征提取,得到各提示文本在所述多个维度中每个维度下的样本融合特征。
3.根据权利要求1所述的方法,其特征在于,所述方法还包括:
基于相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度,确定特征互斥损失;
基于所述特征互斥损失调整所述提示文本。
4.根据权利要求3所述的方法,其特征在于,基于相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度,确定特征互斥损失,包括:
计算相同维度下多个提示文本各自对应的样本融合特征中每两个样本融合特征之间的相似度;
对所述相似度进行求和,得到特征互斥损失。
5.根据权利要求1所述的方法,其特征在于,所述基于所述多个提示文本在所述多个维度中每个维度下的所述样本融合特征、所述样本图像特征以及所述样本图像在至少一个维度下的样本类别标签,确定第一模型损失,包括:
计算各所述提示文本在每个维度下的样本融合特征与所述样本图像特征之间的特征相似度;
基于同一维度下各提示文本对应的样本融合特征与所述样本图像特征之间的特征相似度,确定样本图像在每个维度下的预测类别;
基于所述样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别,确定第一模型损失。
6.根据权利要求5所述的方法,其特征在于,所述基于所述样本图像在至少一个维度下的样本类别标签以及样本图像在每个维度下的预测类别,确定第一模型损失,包括:
基于样本图像在至少一个维度下的类别标签以及样本图像在每个维度下的预测类别进行交叉熵损失计算,得到第一模型损失。
7.根据权利要求1所述的方法,其特征在于,基于所述第一模型损失调整所述分类模型的模型参数以及提示文本,包括:
基于所述第一模型损失调整所述提示生成网络的模型参数以及提示文本。
8.根据权利要求1所述的方法,其特征在于,所述利用分类模型的图像特征提取网络对所述样本图像进行特征提取得到样本图像特征之前,所述方法还包括:
获取多个图像文本对,所述图像文本对包括图像样本和文本样本;
将所述多个文本样本输入至文本特征提取网络,并将所述多个图像样本输入至图像特征提取网络,进行对比学习处理,得到属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像文本对的图像样本和文本样本之间的相似度;
基于属于同一图像文本对的图像样本和文本样本之间的相似度,以及属于不同图像样本对的图像样本和文本样本之间的相似度,获得第二模型损失;
基于所述第二模型损失更新所述文本特征提取网络和所述图像特征提取网络,直至达到预训练结束条件。
9.一种图像分类方法,其特征在于,所述方法包括:
获取待处理图像;
利用分类模型中的图像特征提取网络对所述待处理图像进行特征提取,得到目标图像特征;
利用分类模型中的提示生成网络基于所述目标图像特征生成与多个提示文本分别对应的目标提示特征,所述多个提示文本用于从多个角度描述所述待处理图像中的对象;
利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的目标提示特征,生成所述多个提示文本在所述多个维度中每个维度下的目标融合特征;
基于目标图像特征和所述多个提示文本在所述多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别。
10.根据权利要求9所述的方法,其特征在于,基于各提示文本在所述多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别,包括:
针对每个维度,基于所述目标图像特征和各提示文本在该维度下的融合特征,确定该维度下各提示文本的类别预测结果;对该维度下的各提示文本的类别预测结果进行加权求和,得到待处理图像在该维度下的类别。
11.根据权利要求9所述的方法,其特征在于,所述针对每个维度,基于所述目标图像特征和各提示文本在该维度下的融合特征,确定该维度下各提示文本的类别预测结果,包括:
计算每个目标融合特征与目标图像特征之间的相似度;
针对每个维度下的每个提示文本对应的目标融合特征与目标图像特征之间的相似度,确定每个维度下与每个提示文本对应的类别预测结果。
12.一种分类模型训练装置,其特征在于,所述装置包括:
数据获取模块,用于获取训练数据,所述训练数据包括多张样本图像,所述样本图像在多个维度中至少一个维度下具有样本类别标签;
第一图像特征提取模块,用于利用分类模型中的图像特征提取网络对所述样本图像进行特征提取,得到样本图像特征;
第一提示特征生成模块,用于利用分类模型中的提示生成网络基于所述样本图像特征,生成与多个提示文本分别对应的样本提示特征;
第一融合特征生成模块,用于利用分类模型中的文本特征提取网络基于所述多个提示文本分别对应的参考提示特征和所述多个提示文本分别对应的样本提示特征,生成所述多个提示文本在所述多个维度中每个维度下的样本融合特征;
损失确定模块,用于基于所述多个提示文本在所述多个维度中每个维度下的所述样本融合特征、所述样本图像特征以及所述样本图像在至少一个维度下的样本类别标签,确定第一模型损失;
模型训练模块,用于基于所述第一模型损失调整所述分类模型的模型参数以及所述多个提示文本,调整后的提示文本用于从多个角度描述图像中的对象。
13.一种图像分类装置,其特征在于,所述装置包括:
图像获取模块,用于获取待处理图像;
第二图像特征提取模块,用于利用分类模型中的图像特征提取网络对所述待处理图像进行特征提取,得到目标图像特征;;
第二提示特征生成模块,用于利用分类模型中的提示生成网络基于所述目标图像特征生成与多个提示文本分别对应的目标提示特征,所述多个提示文本用于从多个角度描述所述待处理图像中的对象;
第二融合特征生成模块,用于利用分类模型中的文本特征提取网络基于多个提示文本分别对应的参考提示特征和多个提示文本分别对应的目标提示特征,生成所述多个提示文本在所述多个维度中每个维度下的目标融合特征;
类别确定模块,用于基于目标图像特征和所述多个提示文本在所述多个维度中每个维度下的目标融合特征,确定待处理图像在每个维度下的类别。
14.一种电子设备,其特征在于,包括:
一个或多个处理器;
存储器;
一个或多个程序,其中所述一个或多个程序被存储在所述存储器中并被配置为由所述一个或多个处理器执行,所述一个或多个程序配置用于执行如权利要求1-8或9-11中任意一项所述的方法。
15.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有程序代码,所述程序代码可被处理器调用执行如权利要求1-8或9-11中任意一项所述的方法。
16.一种计算机程序产品,包括计算机程序/指令,其特征在于,该计算机程序/指令被处理器执行时实现权利要求1-8或9-11中任意一项所述方法的步骤。
CN202311279540.3A 2023-09-27 2023-09-27 分类模型训练方法、分类方法、装置及电子设备 Pending CN117315685A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311279540.3A CN117315685A (zh) 2023-09-27 2023-09-27 分类模型训练方法、分类方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311279540.3A CN117315685A (zh) 2023-09-27 2023-09-27 分类模型训练方法、分类方法、装置及电子设备

Publications (1)

Publication Number Publication Date
CN117315685A true CN117315685A (zh) 2023-12-29

Family

ID=89249482

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311279540.3A Pending CN117315685A (zh) 2023-09-27 2023-09-27 分类模型训练方法、分类方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN117315685A (zh)

Similar Documents

Publication Publication Date Title
CN112164391B (zh) 语句处理方法、装置、电子设备及存储介质
CN113627447B (zh) 标签识别方法、装置、计算机设备、存储介质及程序产品
CN112395979B (zh) 基于图像的健康状态识别方法、装置、设备及存储介质
CN112529149B (zh) 一种数据处理方法及相关装置
CN116861995A (zh) 多模态预训练模型的训练及多模态数据处理方法和装置
CN113434716A (zh) 一种跨模态信息检索方法和装置
CN112686023A (zh) 文本数据处理方法、装置、电子设备及存储介质
CN111291695B (zh) 人员违章行为识别模型训练方法、识别方法及计算机设备
CN114581710A (zh) 图像识别方法、装置、设备、可读存储介质及程序产品
CN114299304A (zh) 一种图像处理方法及相关设备
CN116704269B (zh) 数据处理方法、装置、设备及存储介质
CN114155388B (zh) 一种图像识别方法、装置、计算机设备和存储介质
CN113240071B (zh) 图神经网络处理方法、装置、计算机设备及存储介质
CN117315685A (zh) 分类模型训练方法、分类方法、装置及电子设备
CN111339786B (zh) 语音处理方法、装置、电子设备及存储介质
CN117009577A (zh) 一种视频数据处理方法、装置、设备及可读存储介质
CN112132269B (zh) 模型处理方法、装置、设备及存储介质
CN115905605A (zh) 一种数据处理方法、设备以及计算机可读存储介质
CN113569094A (zh) 视频推荐方法、装置、电子设备及存储介质
CN114372205B (zh) 特征量化模型的训练方法、装置以及设备
CN117392260B (zh) 一种图像生成方法及装置
CN117152567B (zh) 特征提取网络的训练方法、分类方法、装置及电子设备
CN113033212B (zh) 文本数据处理方法及装置
CN117475340A (zh) 视频数据处理方法、装置、计算机设备和存储介质
CN116977655A (zh) 图像处理方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication