CN115841596B - 多标签图像分类方法及其模型的训练方法、装置 - Google Patents

多标签图像分类方法及其模型的训练方法、装置 Download PDF

Info

Publication number
CN115841596B
CN115841596B CN202211626780.1A CN202211626780A CN115841596B CN 115841596 B CN115841596 B CN 115841596B CN 202211626780 A CN202211626780 A CN 202211626780A CN 115841596 B CN115841596 B CN 115841596B
Authority
CN
China
Prior art keywords
label
loss
image
classification
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211626780.1A
Other languages
English (en)
Other versions
CN115841596A (zh
Inventor
王晓梅
卢云鸿
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huayuan Computing Technology Shanghai Co ltd
Original Assignee
Huayuan Computing Technology Shanghai Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huayuan Computing Technology Shanghai Co ltd filed Critical Huayuan Computing Technology Shanghai Co ltd
Priority to CN202211626780.1A priority Critical patent/CN115841596B/zh
Publication of CN115841596A publication Critical patent/CN115841596A/zh
Priority to PCT/CN2023/090036 priority patent/WO2024124770A1/zh
Application granted granted Critical
Publication of CN115841596B publication Critical patent/CN115841596B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Software Systems (AREA)
  • General Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种多标签图像分类方法及其模型的训练方法、装置。该模型训练方法包括:获取训练样本;获取标签集的语义特征;将样本图像和语义特征输入多标签图像分类模型,从样本图像中获取视觉特征,基于对比学习网络得到样本图像和标签集的相关度,基于全连接网络得到样本图像映射到分类标签的预测结果;构建相关度矩阵和单位矩阵;基于相关度矩阵、单位矩阵、预测结果和对应的分类标签计算目标损失,调整多标签图像分类模型的参数,直至满足收敛条件。本发明将标签中的语义信息利用起来,除了训练图像信息对应标签的预测结果,还学习图像信息与语义信息的相关性,得到一个高准确度的多标签图像分类模型,提高了多标签图像分类的准确性。

Description

多标签图像分类方法及其模型的训练方法、装置
技术领域
本发明涉及图像分类技术领域,特别涉及一种多标签图像分类方法及其模型的训练方法、装置。
背景技术
图像分类是指从预先设定的多个类别标签中为图像分配所属类别标签的过程,可以包括单标签分类和多标签分类。单标签分类是给每个图像分配一个最能表达图像内容的单个标签;多标签分类可以给每个图像分配对应的多个类别标签,相比起来,多标签图像分类技术更能够充分表达图像内容,应用更加广泛。然而现有的图像分类技术大多只根据图像内容进行分类,对于多标签图像分类来说,容易产生错识别和漏识别标签,准确度不高。
发明内容
本发明要解决的技术问题是为了克服现有技术中多标签图像分类准确度低的缺陷,提供一种多标签图像分类方法及其模型的训练方法、装置。
本发明是通过下述技术方案来解决上述技术问题:
根据本发明的第一方面,提供一种多标签图像分类模型的训练方法,其特征在于,所述多标签图像分类模型的训练方法包括:
获取多个训练样本;所述训练样本包括样本图像和与所述样本图像对应的标签集,所述标签集包括至少两个分类标签;
从所述分类标签中获取标签集的语义特征;
将所述样本图像和所述标签集的语义特征输入多标签图像分类模型,所述多标签图像分类模型基于视觉编码器从所述样本图像中获取视觉特征,并基于对比学习网络对所述视觉特征和所述语义特征进行处理得到所述样本图像和所述标签集的相关度,以及基于全连接网络对所述视觉特征进行处理得到所述样本图像映射到所述分类标签的预测结果;
根据所述相关度构建相关度矩阵和与所述相关度矩阵对应的单位矩阵;
基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失,并根据所述目标损失调整所述多标签图像分类模型的参数,直至满足收敛条件。
较佳地,所述视觉编码器包括卷积神经网络和Transformer(自注意力)网络,所述基于视觉编码器从所述样本图像中获取视觉特征的步骤包括:
将所述样本图像输入所述卷积神经网络进行特征提取,得到对应的特征图;
将所述特征图拉平得到第一特征;
将所述第一特征输入所述Transformer网络进行编码处理,得到第二特征;
根据所述第二特征生成所述视觉特征。
较佳地,所述从所述分类标签中获取标签集的语义特征的步骤包括:
分别将所述分类标签输入语义编码器进行语义映射,得到所述分类标签的语义特征;
获取每个分类标签的权重参数;
根据所述权重参数和所述分类标签的语义特征生成所述标签集的语义特征。
较佳地,所述基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失的步骤包括:
基于所述相关度矩阵和所述单位矩阵计算第一损失;
基于所述预测结果和对应的分类标签计算第二损失;
根据所述第一损失和所述第二损失确定所述目标损失。
较佳地,所述第一损失包括行损失和列损失,所述基于所述相关度矩阵和所述单位矩阵计算第一损失的步骤包括:
将所述相关度矩阵和所述单位矩阵代入第一损失函数,得到所述行损失;所述第一损失函数包括 其中Lx表示行损失,Mij表示相关度矩阵的第i行第j列,eij表示单位矩阵的第i行第j列,σ()表示激活函数,N表示样本图像的数量;
将所述相关度矩阵和所述单位矩阵代入第二损失函数,得到所述列损失;所述第二损失函数包括 其中Ly表示行损失,Mji表示相关度矩阵的第j行第i列,eji表示单位矩阵的第j行第i列,σ()表示激活函数,N表示样本图像的数量;
根据所述行损失和所述列损失确定所述第一损失;
和/或,
所述基于所述预测结果和对应的分类标签计算第二损失的步骤包括:
将所述预测结果和对应的分类标签代入第三损失函数,得到所述第二损失;所述第三损失函数包括 其中,Lv表示第二损失,gkt表示第k个样本图像映射到第t个分类标签的真实标签值,/>表示第k个样本图像映射到第t个分类标签的预测结果值,σ()表示激活函数,N表示样本图像的数量,T表示分类标签的数量。
根据本发明的第二方面,提供一种多标签图像分类方法,所述多标签图像分类方法包括:
获取待分类图像和至少两个候选标签;
从所述候选标签中获取所述候选标签的语义特征;
将所述待分类图像和所述候选标签的语义特征输入多标签图像分类模型进行分类处理,得到所述待分类图像和所述候选标签的相关度以及所述待分类图像映射到所述候选标签的预测结果;
其中,所述多标签图像分类模型通过本发明的多标签图像分类模型的训练方法得到;
根据所述相关度和所述预测结果确定所述待分类图像的分类结果。
根据本发明的第三方面,提供一种多标签图像分类模型的训练装置,所述多标签图像分类模型的训练装置包括:
第一获取模块,用于获取多个训练样本;所述训练样本包括样本图像和与所述样本图像对应的标签集,所述标签集包括至少两个分类标签;
第二获取模块,用于从所述分类标签中获取标签集的语义特征;
第一分类模块,用于将所述样本图像和所述标签集的语义特征输入多标签图像分类模型,所述多标签图像分类模型基于视觉编码器从所述样本图像中获取视觉特征,并基于对比学习网络对所述视觉特征和所述语义特征进行处理得到所述样本图像和所述标签集的相关度,以及基于全连接网络对所述视觉特征进行处理得到所述样本图像映射到所述分类标签的预测结果;
构建模块,用于根据所述相关度构建相关度矩阵和与所述相关度矩阵对应的单位矩阵;
训练模块,用于基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失,并根据所述目标损失调整所述多标签图像分类模型的参数,直至满足收敛条件。
根据本发明的第四方面,提供一种多标签图像分类装置,所述多标签图像分类装置包括:
第三获取模块,用于获取待分类图像和至少两个候选标签;
第四获取模块,用于从所述候选标签中获取所述候选标签的语义特征;
第二分类模块,用于将所述待分类图像和所述候选标签的语义特征输入多标签图像分类模型进行分类处理,得到所述待分类图像和所述候选标签的相关度以及所述待分类图像映射到所述候选标签的预测结果;
其中,所述多标签图像分类模型通过本发明的多标签图像分类模型的训练装置得到;
分类结果确定模块,用于根据所述相关度和所述预测结果确定所述待分类图像的分类结果。
根据本发明的第五方面,提供一种电子设备,包括存储器、处理器及存储在存储器上并用于在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本发明的多标签图像分类模型的训练方法或者本发明的多标签图像分类方法。
根据本发明的第六方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现本发明的多标签图像分类模型的训练方法或者本发明的多标签图像分类方法。
在符合本领域常识的基础上,上述各优选条件,可任意组合,即得本发明各较佳实例。
本发明的积极进步效果在于:
将标签中的语义信息利用起来,预先得到图像对应标签集的语义信息,在训练模型时,除了训练图像信息对应标签的预测结果以外,还学习图像信息与标签集的语义信息的相关性,进而得到一个高准确度的多标签图像分类模型,有效地提高了多标签图像分类的准确性。另外,通过多种损失函数联合优化模型,保证图像分类模型能收敛至最优解,进一步提高了多标签图像分类的准确性。
附图说明
图1为本发明实施例1的多标签图像分类模型的训练方法的流程示意图。
图2为本发明实施例1的多标签图像分类模型的训练方法中语义编码器获取语义特征的框架示意图。
图3为本发明实施例1的多标签图像分类模型的训练方法中视觉编码器获取视觉特征的框架示意图。
图4为本发明实施例1的多标签图像分类模型的训练方法中训练对比学习网络的框架示意图。
图5为本发明实施例2的多标签图像分类方法的流程示意图。
图6为本发明实施例2的多标签图像分类方法中语义编码器获取语义特征的框架示意图。
图7为本发明实施例2的多标签图像分类方法的框架示意图。
图8为本发明实施例2的多标签图像分类方法中通过对比学习网络得到相关度的框架示意图。
图9本发明实施例3的多标签图像分类模型的训练装置的结构示意图。
图10为本发明实施例4的多标签图像分类装置的结构示意图。
图11为本发明实施例5的电子设备的结构示意图。
具体实施方式
下面通过实施例的方式进一步说明本发明,但并不因此将本发明限制在所述的实施例范围之中。
实施例1
本实施例提供一种多标签图像分类模型的训练方法,如图1所示,该多标签图像分类模型的训练方法包括以下步骤
S11、获取多个训练样本。
其中,训练样本包括样本图像和与样本图像对应的标签集,标签集包括至少两个分类标签。
在本实施例中,一个训练样本包括一张样本图像和一个与样本图像对应的标签集。样本图像可以是设备拍摄或网络爬取的,样本图像中包括多个不同类别的图像信息,比如某张风景图像中包括太阳、云、山和小河等等图像信息,对样本图像中不同的图像信息分别标注标签(即分类标签),从而得到与样本图像对应的标签集。
作为可选的一种实施方式,为了更好的提取样本图像的特征,对样本图像进行预处理,比如将所有样本图像进行短边尺寸相同的缩放,然后对样本图像的像素点进行归一化操作,即把像素点对应到0到1之间,从而消除图像特征中单位和尺度差异的影响,又比如做一些色彩抖动操作,通过调整图像的亮度,饱和度和对比度来对图像的颜色方面做增强,以消除图像在不同背景中存在的差异性等等,当然本实施例并不限于上述图像预处理操作。
S12、从分类标签中获取标签集的语义特征。
在本实施例中,需要对标签集进行编码,也即获取标签集的语义特征,作为可选的一种实施方式,分别将分类标签输入语义编码器进行语义映射,得到分类标签的语义特征;再获取每个分类标签的权重参数;最后根据权重参数和分类标签的语义特征生成标签集的语义特征。
举例说明,如图2所示,假设某张图像对应的标签集包括t个分类标签,则将这k个分类标签输入语义编码器进行语义映射,可得到维度为t×d的语义特征。作为可选的一种实施方式,语义映射的方法可以采用Glove、Bert、ELMo和word2vec(一些词向量经典模型)等用来映射每个词而产生语义特征的相关模型。由于语义编码器在进行语义映射时,还包括猜测相邻位置的词,因此语义特征还用于表征分类标签与分类标签之间的相关性。
作为可选的一种实施方式,为了得到标签集整体的语义特征,语义编码器还用于获取每个分类标签的权重参数,参见图2,采用加权平均的方式将t×d的语义特征的融合为维度为1×d的语义特征,也即每个分类标签的语义特征乘以对应的权重参数,再取平均值得到标签集的语义特征,其中,标签集下的所有分类标签的权重参数之后为1。
S13、将样本图像和标签集的语义特征输入多标签图像分类模型,多标签图像分类模型基于视觉编码器从样本图像中获取视觉特征,并基于对比学习网络对视觉特征和语义特征进行处理得到样本图像和标签集的相关度,以及基于全连接网络对视觉特征进行处理得到样本图像映射到分类标签的预测结果。
在本实施例中,还需要对样本图像进行编码,也即获取样本图像的视觉特征,可通过多标签图像分类模型的视觉编码器来从样本图像中获取语义特征。作为可选的一种实施方式,视觉编码器包括卷积神经网络和Transformer网络,可将样本图像输入卷积神经网络进行特征提取,得到对应的特征图;然后将特征图拉平得到第一特征;再将第一特征输入Transformer网络进行编码处理,得到第二特征;最后根据第二特征生成视觉特征。
举例说明,如图3所示,将H×W×C的样本图像输入到卷积神经网络进行特征提取,其中,H表示样本图像的高,W表示样本图像的宽,C表示样本图像的通道数,一般C=3,也即对应R、G、B三个通道,当然如果是灰度图像,就是单通道,通道数C=1。卷积神经网络包括多层卷积层,通过卷积运算以分层形式提取样本图像的局部特征,获取最后一层卷积层的输出,可得到特征维度为c×c×d的特征图。作为可选的一种实施方式,卷积神经网络可以采用VGG网络(一种经典卷积神经网络)和Resnet网络(残差网络)等结构。
参见图3,将特征维度为c×c×d的特征图拉平为k×d的第一特征,其中,k=c2,之后将第一特征送入到Transformer网络中进行进一步的编码处理,得到特征维度为k×d的第二特征,Transformer网络可以反映复杂的空间变换和视觉元素之间的远距离关系等,从而构成全局表征。在本实施例中,视觉编码器通过卷积神经网络和Transformer网络输出的特征,能够同时考虑到局部特征细节和全局表征,更有利于后续的图像分类。
作为可选的一种实施方式,对k×d的第二特征取平均值得到样本图像的视觉特征,其视觉特征的维度为1×d。
需要说明的是,本实施例的语义编码器是预训练的,而视觉编码器是在训练多标签图像分类模型中学习出来的。
在本实施例中,多标签图像分类模型中引入对比学习网络和全连接网络,在得到视觉特征和语义特征后,分为两个支路,一路是将视觉特征和语义特征输入到对比学习网络,并基于对比学习网络对视觉特征和语义特征进行处理,以得到样本图像和标签集的相关度;另一支路是将视觉特征送入到全连接网络进行全连接映射,从而基于全连接网络对视觉特征进行处理得到样本图像映射到分类标签的预测结果(即每个分类标签的预测概率)。
作为可选的一种实施方式,对比学习网络在训练阶段的核心内容如图4所示,假设有N个训练样本,则包含N个样本图像和对应的N个标签集,将N×d的视觉特征X和N×d的语义特征Y一一相乘并获得其余弦相似性(即样本图像和标签集的相关度),余弦相似性用xi·yj表示,xi·yj=|xi||yj|cosθ,其中,xi表示第i个视觉特征,yj表示第j个语义特征。比如,假设有两个训练样本,则得到的余弦相似性有x1·y1、x1·y2、x2·y1和x2·y2
S14、根据相关度构建相关度矩阵和与相关度矩阵对应的单位矩阵。
其中,相关度矩阵如图4所示,成匹配对的视觉特征和语义特征设置在从左上角到右下角的对角线上。由于在训练时,需要最大化样本图像的视觉特征与对应匹配的标签集的语义特征的余弦相似性,同时最小化非匹配对的视觉特征和语义特征的余弦相似性,因此还需要构建单位矩阵,单位矩阵具有如下特性,即从左上角到右下角的对角线上的元素均为1,除此以外全都为0,并且任何矩阵与单位矩阵相乘都等于自身,具有特殊的作用。
举例说明,假设有N个视觉特征和对应的N个语义特征,则得到N×N个余弦相似性,从而构建得到维度为N×N的余弦相似性矩阵(即相关度矩阵)和维度为N×N的单位矩阵。
S15、基于相关度矩阵、单位矩阵、预测结果和对应的分类标签计算目标损失,并根据目标损失调整多标签图像分类模型的参数,直至满足收敛条件。
在本实施例中,基于相关度矩阵和单位矩阵计算第一损失,以及基于预测结果和对应的分类标签计算第二损失,最后根据第一损失和第二损失确定目标损失。
作为可选的一种实施方式,第一损失包括行损失和列损失,对语义特征特征和视觉特征分别做交叉熵损失,也即分别从相关度矩阵的行和列上来最大化对角线上的数值,从而分别获得了行损失Lx和列损失Ly。具体地,将相关度矩阵和单位矩阵代入第一损失函数,得到行损失,第一损失函数的计算公式包括:
其中Lx表示行损失,Mij表示相关度矩阵的第i行第j列,eij表示单位矩阵的第i行第j列,σ()表示激活函数,N表示样本图像的数量。
将相关度矩阵和单位矩阵代入第二损失函数,得到列损失,第二损失函数的计算公式包括:
其中Ly表示行损失,Mji表示相关度矩阵的第j行第i列,eji表示单位矩阵的第j行第i列,σ()表示激活函数,N表示样本图像的数量。
最后,根据行损失和列损失确定第一损失,作为可选的一种实施方式,通过行损失和列损失的平均值得到第一损失;作为可选的另一个实施方式,行损失和列损失分别从相关度矩阵的行和列进行遍历并计算,也即遍历的顺序不一样,从理想的计算结果上来说Lx=Ly,只是意义上不一样的,因此也可以将行损失或列损失直接作为第一损失,可以减少计算量。用Lxy表示第一损失,则Lxy=Lx=Ly,或者Lxy=(Lx+Ly)/2,在实际应用中,为了保证模型的准确度,通常取Lxy=(Lx+Ly)/2。
作为可选的一种实施方式,基于预测结果和与样本图像对应的分类标签计算交叉熵损失(即第二损失)。具体地,将预测结果和对应的分类标签代入第三损失函数,得到第二损失,第三损失函数的计算公式包括:
其中,Lv表示第二损失,gkt表示第k个样本图像映射到第t个分类标签的真实标签值,也即“0”或者“1”的标签值,表示第k个样本图像映射到第t个分类标签的预测结果值,也即0到1的预测概率值,σ()表示激活函数,N表示样本图像的数量,T表示分类标签的数量。
作为可选的一种实施方式,训练样本中可能存在样本难以不均衡的问题,比如样本图像复杂,或标签不常见等等,还可以在模型训练过程中,对困难样本赋予更大的权重,使其注重对困难样本的学习。
最后,联合第一损失和第二损失,得到多标签图像分类模型的总损失(即目标损失),通过反向传播目标损失,调整多标签图像分类模型的参数,直至多标签图像分类模型满足收敛条件。作为可选的一种实施方式,收敛条件可以是目标损失小于一个预先设定的阈值,通过预先设定一个比较小的损失阈值,当目标损失小于设定的损失阈值时,确定该多标签图像分类模型收敛;作为可选的另一种实施方式,收敛条件也可以是多标签图像分类模型训练的迭代次数大于一个预设的迭代次数,通过预先设定一个比较大的迭代次数,当多标签图像分类模型训练的迭代次数大于设定的最大迭代次数时,确定该多标签图像分类模型收敛。
作为可选的一种实施方式,设置可学习超参数λ来协调第一损失和第二损失在多标签图像分类模型训练中的比重,用λ表示目标损失,λ为0到1的取值,其计算公式如下:
L=λLv+(1-λ)Lxy
作为可选的一种实施方式,在多标签图像分类模型训练初期,λ的值预设为一个较小的值,随着多标签图像分类模型训练的迭代次数增加,逐渐增大λ的值,并继续进行迭代训练。当然本实施例不限上述的调参方式,可以根据实际情况进行手动调参。
作为可选的一种实施方式,多标签图像分类模型的优化目标为目标损失的最小化,设定目标损失L的目标阈值,通过迭代训练,直至目标损失L的值低于目标阈值,且模型参数变化很小,确定该多标签图像分类模型收敛。
本实施例具有如下有益效果:
本实施例将标签中的语义信息利用起来,预先得到图像对应标签集的语义信息,在训练模型时,除了训练图像信息对应标签的预测结果以外,还学习图像信息与标签集的语义信息的相关性,进而得到一个高准确度的多标签图像分类模型,有效地提高了多标签图像分类的准确性。另外,本实施例通过多种损失函数联合优化模型,保证图像分类模型能收敛至最优解,进一步提高了多标签图像分类的准确性。
实施例2
本实施例提供一种多标签图像分类方法,如图5所示,该多标签图像分类方法包括以下步骤:
S21、获取待分类图像和至少两个候选标签。
S22、从候选标签中获取候选标签的语义特征。
作为可选的一种实施方式,将所有的候选标签输入语义编码器进行语义映射,得到候选标签的语义特征。如图6所示,假设有T个候选标签,则得到维度为T×d的语义特征。
S23、将待分类图像和候选标签的语义特征输入多标签图像分类模型进行分类处理,得到待分类图像和候选标签的相关度以及待分类图像映射到候选标签的预测结果;
其中,多标签图像分类模型通过实施例1的多标签图像分类模型的训练方法得到。
在本实施例中,如图7所示,将一张图像(即待分类图像)输入多标签图像分类模型的视觉编码器中获取视觉特征,将所有的候选标签输入语义编码器中获取语义特征,在获取视觉特征和语义特征后,同样分为两个支路,一路是将视觉特征和语义特征输入到训练好的对比学习网络,获取待分类图像和每个候选标签的相关度;另一支路是将视觉特征送入到全连接网络进行全连接映射,获取样本图像映射到每个分类标签的预测结果(即每个分类标签的预测概率)。
作为可选的一种实施方式,对比学习网络在预测阶段的核心内容如图8所示,将1×d的视觉特征X和T×d的语义特征Y一一相乘,计算得到视觉特征在每一个候选标签上的余弦相似性(即待分类图像和每个候选标签的相关度)。参见图8,假设有T个候选标签,则得到的余弦相似性有x1·y1、x1·y2、……x1·yT
S24、根据相关度和预测结果确定待分类图像的分类结果。
分别得到全连接网络的预测结果和对比学习网络的相关度后,待分类图像的分类结果可以是p=λpv+(1-λ)pd,其中,pv为全连接网络的预测结果,pd为对比学习网络的相关度,λ为权重参数,取值与训练阶段的可学习超参数一致,也可以根据分类结果来调节训练阶段的可学习超参数λ。
作为可选的一种实施方式,可根据实际情况设定分类结果的标签阈值,比如分类结果在0.5以上,表示该候选标签为待分类图像对应的标签。
举例说明,假设待分类图像的候选标签中包括帽子,上衣,裤子,鞋子,围巾和裙子,通过多标签图像分类模型进行分类处理,得到待分类图像和候选标签的相关度分别为[0.1,0.2,0.3,0.7,0.5,0.6],以及得到待分类图像映射到候选标签的预测结果分别为[0.3,0.2,0.4,0.4,0.3,0.6],设定λ的值为0.5,则计算得到分类结果为[0.15,0.2,0.35,0.55,0.4,0.6],其中,这些数值按照顺序依次对应帽子,上衣,裤子,鞋子,围巾和裙子,设数值0.5以上的候选标签为待分类图像的标签。可以看到,传统的全连接网络预测得到待分类图像的标签只有裙子,而本实施例得到待分类图像的标签包括鞋子和裙子。
本实施例具有如下有益效果:
在本实施例中,引入候选标签的语义信息,并且用于进行分类处理的多标签分类模型是通过实施例1的多标签图像分类模型的训练方法得到的,该多标签图像分类模型在应用于具体的分类任务时,除了传统的预测结果外,通过对比学习网络提高图像信息与候选标签的语义信息的相关度,其输出的多标签图像分类结果也会更加准确。
实施例3
本实施例提供一种多标签图像分类模型的训练系统,如图9所示,该多标签图像分类模型的训练系统包括第一获取模块11、第二获取模块12、第一分类模块13、构建模块14和训练模块15。
第一获取模块11用于获取多个训练样本。其中,训练样本包括样本图像和与样本图像对应的标签集,标签集包括至少两个分类标签。
在本实施例中,一个训练样本包括一张样本图像和一个与样本图像对应的标签集。样本图像可以是设备拍摄或网络爬取的,样本图像中包括多个不同类别的图像信息,比如某张风景图像中包括太阳、云、山和小河等等图像信息,对样本图像中不同的图像信息分别标注标签(即分类标签),从而得到与样本图像对应的标签集。
作为可选的一种实施方式,为了更好的提取样本图像的特征,第一获取模块11对样本图像进行预处理,比如将所有样本图像进行短边尺寸相同的缩放,然后对样本图像的像素点进行归一化操作,即把像素点对应到0到1之间,从而消除图像特征中单位和尺度差异的影响,又比如做一些色彩抖动操作,通过调整图像的亮度,饱和度和对比度来对图像的颜色方面做增强,以消除图像在不同背景中存在的差异性等等,当然本实施例并不限于上述图像预处理操作。
第二获取模块12用于从分类标签中获取标签集的语义特征。在本实施例中,需要对标签集进行编码,也即获取标签集的语义特征,作为可选的一种实施方式,第二获取模块12分别将分类标签输入语义编码器进行语义映射,得到分类标签的语义特征;第二获取模块12再获取每个分类标签的权重参数;最后第二获取模块12根据权重参数和分类标签的语义特征生成标签集的语义特征。
举例说明,如图2所示,假设某张图像对应的标签集包括t个分类标签,则将这k个分类标签输入语义编码器进行语义映射,可得到维度为t×d的语义特征。作为可选的一种实施方式,语义映射的方法可以采用Glove、Bert、ELMo和word2vec等用来映射每个词而产生语义特征的相关模型。由于语义编码器在进行语义映射时,还包括猜测相邻位置的词,因此语义特征还用于表征分类标签与分类标签之间的相关性。
作为可选的一种实施方式,为了得到标签集整体的语义特征,语义编码器还用于获取每个分类标签的权重参数,参见图2,第二获取模块12采用加权平均的方式将t×d的语义特征的融合为维度为1×d的语义特征,也即每个分类标签的语义特征乘以对应的权重参数,再取平均值得到标签集的语义特征,其中,标签集下的所有分类标签的权重参数之后为1。
第一分类模块13用于将样本图像和标签集的语义特征输入多标签图像分类模型,多标签图像分类模型基于视觉编码器从样本图像中获取视觉特征,并基于对比学习网络对视觉特征和语义特征进行处理得到样本图像和标签集的相关度,以及基于全连接网络对视觉特征进行处理得到样本图像映射到分类标签的预测结果。
在本实施例中,第一分类模块13还需要对样本图像进行编码,也即获取样本图像的视觉特征,可通过多标签图像分类模型的视觉编码器来从样本图像中获取语义特征。作为可选的一种实施方式,视觉编码器包括卷积神经网络和Transformer网络,可将样本图像输入卷积神经网络进行特征提取,得到对应的特征图;然后将特征图拉平得到第一特征;再将第一特征输入Transformer网络进行编码处理,得到第二特征;最后根据第二特征生成视觉特征。
举例说明,如图3所示,第一分类模块13将H×W×C的样本图像输入到卷积神经网络进行特征提取,其中,H表示样本图像的高,W表示样本图像的宽,C表示样本图像的通道数,一般C=3,也即对应R、G、B三个通道,当然如果是灰度图像,就是单通道,通道数C=1。卷积神经网络包括多层卷积层,通过卷积运算以分层形式提取样本图像的局部特征,第一分类模块13获取最后一层卷积层的输出,可得到特征维度为c×c×d的特征图。作为可选的一种实施方式,卷积神经网络可以采用VGG网络和Resnet网络等结构。
参见图3,第一分类模块13将特征维度为c×c×d的特征图拉平为k×d的第一特征,其中,k=c2,之后第一分类模块13将第一特征送入到Transformer网络中进行进一步的编码处理,得到特征维度为k×d的第二特征,Transformer网络可以反映复杂的空间变换和视觉元素之间的远距离关系等,从而构成全局表征。在本实施例中,视觉编码器通过卷积神经网络和Transformer网络输出的特征,能够同时考虑到局部特征细节和全局表征,更有利于后续的图像分类。
作为可选的一种实施方式,第一分类模块13对k×d的第二特征取平均值得到样本图像的视觉特征,其视觉特征的维度为1×d。
需要说明的是,本实施例的语义编码器是预训练的,而视觉编码器是在训练多标签图像分类模型中学习出来的。
在本实施例中,多标签图像分类模型中引入对比学习网络和全连接网络,在得到视觉特征和语义特征后,第一分类模块13将视觉特征和语义特征输入到对比学习网络,并基于对比学习网络对视觉特征和语义特征进行处理,以得到样本图像和标签集的相关度;同时,第一分类模块13将视觉特征送入到全连接网络进行全连接映射,从而基于全连接网络对视觉特征进行处理得到样本图像映射到分类标签的预测结果(即每个分类标签的预测概率)。
作为可选的一种实施方式,对比学习网络在训练阶段的核心内容如图4所示,假设有N个训练样本,则包含N个样本图像和对应的N个标签集,第一分类模块13将N×d的视觉特征X和N×d的语义特征Y一一相乘并获得其余弦相似性(即样本图像和标签集的相关度),余弦相似性用xi·yj表示,xi·yj=|xi||yj|cosθ,其中,xi表示第i个视觉特征,yj表示第j个语义特征。比如,假设有两个训练样本,则得到的余弦相似性有x1·y1、x1·y2、x2·y1和x2·y2
构建模块14用于根据相关度构建相关度矩阵和与相关度矩阵对应的单位矩阵。其中,相关度矩阵如图4所示,成匹配对的视觉特征和语义特征设置在从左上角到右下角的对角线上。由于在训练时,需要最大化样本图像的视觉特征与对应匹配的标签集的语义特征的余弦相似性,同时最小化非匹配对的视觉特征和语义特征的余弦相似性,因此构建模块14还需要构建单位矩阵,单位矩阵具有如下特性,即从左上角到右下角的对角线上的元素均为1,除此以外全都为0,并且任何矩阵与单位矩阵相乘都等于自身,具有特殊的作用。
举例说明,假设有N个视觉特征和对应的N个语义特征,则得到N×N个余弦相似性,从而构建模块14构建得到维度为N×N的余弦相似性矩阵(即相关度矩阵)和维度为N×N的单位矩阵。
训练模块15用于基于相关度矩阵、单位矩阵、预测结果和对应的分类标签计算目标损失,并根据目标损失调整多标签图像分类模型的参数,直至满足收敛条件。在本实施例中,训练模块15基于相关度矩阵和单位矩阵计算第一损失,以及基于预测结果和对应的分类标签计算第二损失,最后训练模块15根据第一损失和第二损失确定目标损失。
作为可选的一种实施方式,第一损失包括行损失和列损失,训练模块15对语义特征特征和视觉特征分别做交叉熵损失,也即分别从相关度矩阵的行和列上来最大化对角线上的数值,从而分别获得了行损失Lx和列损失Ly。具体地,训练模块15将相关度矩阵和单位矩阵代入第一损失函数,得到行损失,第一损失函数的计算公式包括:
其中Lx表示行损失,Mij表示相关度矩阵的第i行第j列,eij表示单位矩阵的第i行第j列,σ()表示激活函数,N表示样本图像的数量。
训练模块15将相关度矩阵和单位矩阵代入第二损失函数,得到列损失,第二损失函数的计算公式包括:
其中Ly表示行损失,Mji表示相关度矩阵的第j行第i列,eji表示单位矩阵的第j行第i列,σ()表示激活函数,N表示样本图像的数量。
最后,训练模块15根据行损失和列损失确定第一损失,作为可选的一种实施方式,训练模块15通过行损失和列损失的平均值得到第一损失;作为可选的另一个实施方式,行损失和列损失分别从相关度矩阵的行和列进行遍历并计算,也即遍历的顺序不一样,从理想的计算结果上来说Lx=Ly,只是意义上不一样的,因此训练模块15也可以将行损失或列损失直接作为第一损失,可以减少计算量。用Lxy表示第一损失,则Lxy=Lx=Ly,或者Lxy=(Lx+Ly)/2,在实际应用中,为了保证模型的准确度,通常取Lxy=(Lx+Ly)/2。
作为可选的一种实施方式,训练模块15基于预测结果和与样本图像对应的分类标签计算交叉熵损失(即第二损失)。具体地,训练模块15将预测结果和对应的分类标签代入第三损失函数,得到第二损失,第三损失函数的计算公式包括:
其中,Lv表示第二损失,gkt表示第k个样本图像映射到第t个分类标签的真实标签值,也即“0”或者“1”的标签值,表示第k个样本图像映射到第t个分类标签的预测结果值,也即0到1的预测概率值,σ()表示激活函数,N表示样本图像的数量,T表示分类标签的数量。
作为可选的一种实施方式,训练样本中可能存在样本难以不均衡的问题,比如样本图像复杂,或标签不常见等等,还可以在模型训练过程中,对困难样本赋予更大的权重,使其注重对困难样本的学习。
最后,训练模块15联合第一损失和第二损失,得到多标签图像分类模型的总损失(即目标损失),训练模块15通过反向传播目标损失,调整多标签图像分类模型的参数,直至多标签图像分类模型满足收敛条件。作为可选的一种实施方式,收敛条件可以是目标损失小于一个预先设定的阈值,通过预先设定一个比较小的损失阈值,当目标损失小于设定的损失阈值时,确定该多标签图像分类模型收敛;作为可选的另一种实施方式,收敛条件也可以是多标签图像分类模型训练的迭代次数大于一个预设的迭代次数,通过预先设定一个比较大的迭代次数,当多标签图像分类模型训练的迭代次数大于设定的最大迭代次数时,确定该多标签图像分类模型收敛。
作为可选的一种实施方式,训练模块15设置可学习超参数λ来协调第一损失和第二损失在多标签图像分类模型训练中的比重,用λ表示目标损失,λ为0到1的取值,其计算公式如下:
L=λLv+(1-λ)Lxy
作为可选的一种实施方式,在多标签图像分类模型训练初期,λ的值预设为一个较小的值,随着多标签图像分类模型训练的迭代次数增加,训练模块15逐渐增大λ的值,并继续进行迭代训练。当然本实施例不限上述的调参方式,可以根据实际情况进行手动调参。
作为可选的一种实施方式,多标签图像分类模型的优化目标为目标损失的最小化,训练模块15设定目标损失L的目标阈值,通过迭代训练,直至目标损失L的值低于目标阈值,且模型参数变化很小,确定该多标签图像分类模型收敛。
实施例4
本实施例提供一种多标签图像分类系统,如图10所示,该多标签图像分类系统包括第三获取模块21、第四获取模块22、第二分类模块23和分类结果确定模块24。
第三获取模块21用于获取待分类图像和至少两个候选标签。
第四获取模块22用于从候选标签中获取候选标签的语义特征。作为可选的一种实施方式,第四获取模块22将所有的候选标签输入语义编码器进行语义映射,得到候选标签的语义特征。如图6所示,假设有T个候选标签,则得到维度为T×d的语义特征。
第二分类模块23用于将待分类图像和候选标签的语义特征输入多标签图像分类模型进行分类处理,得到待分类图像和候选标签的相关度以及待分类图像映射到候选标签的预测结果。其中,多标签图像分类模型通过实施例3的多标签图像分类模型的训练系统得到。
在本实施例中,如图7所示,第二分类模块23将一张图像(即待分类图像)输入多标签图像分类模型的视觉编码器中获取视觉特征,第四获取模块22将所有的候选标签输入语义编码器中获取语义特征,第二分类模块23将视觉特征和语义特征输入到训练好的对比学习网络,获取待分类图像和每个候选标签的相关度pd;同时,第二分类模块23将视觉特征送入到全连接网络进行全连接映射,获取样本图像映射到每个分类标签的预测结果pv(即每个分类标签的预测概率)。
作为可选的一种实施方式,对比学习网络在预测阶段的核心内容如图8所示,将1×d的视觉特征X和T×d的语义特征Y一一相乘,第二分类模块23计算得到视觉特征在每一个候选标签上的余弦相似性(即待分类图像和每个候选标签的相关度)。参见图8,假设有T个候选标签,则得到的余弦相似性有x1·y1、x1·y2、……x1·yT
分类结果确定模块24用于根据相关度和预测结果确定待分类图像的分类结果。分别得到全连接网络的预测结果和对比学习网络的相关度后,待分类图像的分类结果可以是p=λpv+(1-λ)pd,其中,pv为全连接网络的预测结果,pd为对比学习网络的相关度,λ为权重参数,取值与训练阶段的可学习超参数一致,也可以根据分类结果来调节训练阶段的可学习超参数λ。
作为可选的一种实施方式,分类结果确定模块24可根据实际情况设定分类结果的标签阈值,比如分类结果在0.5以上,表示该候选标签为待分类图像对应的标签。
举例说明,假设待分类图像的候选标签中包括帽子,上衣,裤子,鞋子,围巾和裙子,通过多标签图像分类模型进行分类处理,得到待分类图像和候选标签的相关度分别为[0.1,0.2,0.3,0.7,0.5,0.6],以及得到待分类图像映射到候选标签的预测结果分别为[0.3,0.2,0.4,0.4,0.3,0.6],设定λ的值为0.5,则计算得到分类结果为[0.15,0.2,0.35,0.55,0.4,0.6],其中,这些数值按照顺序依次对应帽子,上衣,裤子,鞋子,围巾和裙子,设数值0.5以上的候选标签为待分类图像的标签。可以看到,传统的全连接网络预测得到待分类图像的标签只有裙子,而本实施例得到待分类图像的标签包括鞋子和裙子。
实施例5
本实施例提供一种电子设备,所述电子设备包括存储器、处理器及存储在存储器上并用于在处理器上运行的计算机程序,所述处理器执行所述程序时实现实施例1的多标签图像分类模型的训练方法以及实施例2的多标签图像分类方法。
如图11所示的电子设备30仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。
电子设备30可以以通用计算设备的形式表现,例如其可以为服务器设备。电子设备30的组件可以包括但不限于:上述至少一个处理器31、上述至少一个存储器32、连接不同系统组件(包括存储器32和处理器31)的总线33。
总线33包括数据总线、地址总线和控制总线。
存储器32可以包括易失性存储器,例如随机存取存储器(RAM)321和高速缓存存储器322,还可以进一步包括只读存储器(ROM)323。
存储器32还可以包括具有一组(至少一个)程序模块324的程序工具325,这样的程序模块324包括但不限于:操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。
处理器31通过运行存储在存储器32中的计算机程序,从而执行各种功能应用以及数据处理,例如本发明实施例1的多标签图像分类模型的训练方法以及实施例2的多标签图像分类方法。
电子设备30也可以与一个或多个外部设备34通信。这种通信可以通过输入/输出(I/O)接口35进行。并且,模型生成的电子设备30还可以通过网络适配器36与一个或者多个网络通信。如图11所示,网络适配器36通过总线33与模型生成的电子设备30的其它模块通信。应当明白,尽管图11未标示,可以结合模型生成的电子设备30使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理器、外部磁盘驱动阵列、RAID(磁盘阵列)系统、磁带驱动器以及数据备份存储系统等。
应当注意,尽管在上文详细描述中提及了电子设备的若干单元/模块或子单元/模块,但是这种划分仅仅是示例性的并非强制性的。实际上,根据本发明的实施方式,上文描述的两个或更多单元/模块的特征和功能可以在一个单元/模块中具体化。反之,上文描述的一个单元/模块的特征和功能可以进一步划分为由多个单元/模块来具体化。
实施例6
本实施例提供了一种计算机可读存储介质,其上存储有计算机程序,所述程序被处理器执行时实现实施例1的多标签图像分类模型的训练方法以及实施例2的多标签图像分类方法。
其中,可读存储介质可以采用的更具体可以包括但不限于:便携式盘、硬盘、随机存取存储器、只读存储器、可擦拭可编程只读存储器、光存储器件、磁存储器件或上述的任意合适的组合。
在可选的一种实施方式中,本发明还可以实现为一种程序产品的形式,其包括程序代码,当所述程序产品在终端设备上运行时,所述程序代码用于使所述终端设备执行实现实施例1的多标签图像分类模型的训练方法以及实施例2的多标签图像分类方法。
其中,可以以一种或多种程序设计语言的任意组合来编写用于执行本发明的程序代码,所述程序代码可以完全地在用户设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户设备上部分在远程设备上执行或完全在远程设备上执行。
虽然以上描述了本发明的具体实施方式,但是本领域的技术人员应当理解,这仅是举例说明,本发明的保护范围是由所附权利要求书限定的。本领域的技术人员在不背离本发明的原理和实质的前提下,可以对这些实施方式做出多种变更或修改,但这些变更和修改均落入本发明的保护范围。

Claims (11)

1.一种多标签图像分类模型的训练方法,其特征在于,所述多标签图像分类模型的训练方法包括:
获取多个训练样本;所述训练样本包括样本图像和与所述样本图像对应的标签集,所述标签集包括至少两个分类标签;
从所述分类标签中获取标签集的语义特征;
将所述样本图像和所述标签集的语义特征输入多标签图像分类模型,所述多标签图像分类模型基于视觉编码器从所述样本图像中获取视觉特征,并基于对比学习网络对所述视觉特征和所述语义特征进行处理得到所述样本图像和所述标签集的相关度,以及基于全连接网络对所述视觉特征进行处理得到所述样本图像映射到所述分类标签的预测结果;
根据所述相关度构建相关度矩阵和与所述相关度矩阵对应的单位矩阵;
基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失,并根据所述目标损失调整所述多标签图像分类模型的参数,直至满足收敛条件;
所述基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的标签集计算目标损失的步骤包括:
基于所述相关度矩阵和所述单位矩阵计算第一损失;
基于所述预测结果和对应的分类标签计算第二损失;
根据所述第一损失和所述第二损失确定所述目标损失;
所述第一损失包括行损失和列损失,所述基于所述相关度矩阵和所述单位矩阵计算第一损失的步骤包括:
将所述相关度矩阵和所述单位矩阵代入第一损失函数,得到所述行损失;所述第一损失函数包括 其中Lx表示行损失,Mij表示相关度矩阵的第i行第j列,eij表示单位矩阵的第i行第j列,σ()表示激活函数,N表示样本图像的数量;
将所述相关度矩阵和所述单位矩阵代入第二损失函数,得到所述列损失;所述第二损失函数包括 其中Ly表示行损失,Mji表示相关度矩阵的第j行第i列,eji表示单位矩阵的第j行第i列,σ()表示激活函数,N表示样本图像的数量;
根据所述行损失和所述列损失确定所述第一损失。
2.一种多标签图像分类模型的训练方法,其特征在于,所述多标签图像分类模型的训练方法包括:
获取多个训练样本;所述训练样本包括样本图像和与所述样本图像对应的标签集,所述标签集包括至少两个分类标签;
从所述分类标签中获取标签集的语义特征;
将所述样本图像和所述标签集的语义特征输入多标签图像分类模型,所述多标签图像分类模型基于视觉编码器从所述样本图像中获取视觉特征,并基于对比学习网络对所述视觉特征和所述语义特征进行处理得到所述样本图像和所述标签集的相关度,以及基于全连接网络对所述视觉特征进行处理得到所述样本图像映射到所述分类标签的预测结果;
根据所述相关度构建相关度矩阵和与所述相关度矩阵对应的单位矩阵;
基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失,并根据所述目标损失调整所述多标签图像分类模型的参数,直至满足收敛条件;
所述基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的标签集计算目标损失的步骤包括:
基于所述相关度矩阵和所述单位矩阵计算第一损失;
基于所述预测结果和对应的分类标签计算第二损失;
根据所述第一损失和所述第二损失确定所述目标损失;
所述基于所述预测结果和对应的分类标签计算第二损失的步骤包括:
将所述预测结果和对应的分类标签代入第三损失函数,得到所述第二损失;所述第三损失函数包括 其中,Lv表示第二损失,gkt表示第k个样本图像映射到第t个分类标签的真实标签值,/>表示第k个样本图像映射到第t个分类标签的预测结果值,σ()表示激活函数,N表示样本图像的数量,T表示分类标签的数量。
3.根据权利要求1或2所述的多标签图像分类模型的训练方法,其特征在于,所述视觉编码器包括卷积神经网络和Transformer网络,所述基于视觉编码器从所述样本图像中获取视觉特征的步骤包括:
将所述样本图像输入所述卷积神经网络进行特征提取,得到对应的特征图;
将所述特征图拉平得到第一特征;
将所述第一特征输入所述Transformer网络进行编码处理,得到第二特征;
根据所述第二特征生成所述视觉特征。
4.根据权利要求1或2所述的多标签图像分类模型的训练方法,其特征在于,所述从所述分类标签中获取标签集的语义特征的步骤包括:
分别将所述分类标签输入语义编码器进行语义映射,得到所述分类标签的语义特征;
获取每个分类标签的权重参数;
根据所述权重参数和所述分类标签的语义特征生成所述标签集的语义特征。
5.根据权利要求1所述的多标签图像分类模型的训练方法,其特征在于,所述基于所述预测结果和对应的分类标签计算第二损失的步骤包括:
将所述预测结果和对应的分类标签代入第三损失函数,得到所述第二损失;所述第三损失函数包括 其中,Lv表示第二损失,gkt表示第k个样本图像映射到第t个分类标签的真实标签值,/>表示第k个样本图像映射到第t个分类标签的预测结果值,σ()表示激活函数,N表示样本图像的数量,T表示分类标签的数量。
6.一种多标签图像分类方法,其特征在于,所述多标签图像分类方法包括:
获取待分类图像和至少两个候选标签;
从所述候选标签中获取所述候选标签的语义特征;
将所述待分类图像和所述候选标签的语义特征输入多标签图像分类模型进行分类处理,得到所述待分类图像和所述候选标签的相关度以及所述待分类图像映射到所述候选标签的预测结果;
其中,所述多标签图像分类模型通过如权利要求1-5任一项所述的多标签图像分类模型的训练方法得到;
根据所述相关度和所述预测结果确定所述待分类图像的分类结果。
7.一种多标签图像分类模型的训练装置,其特征在于,所述多标签图像分类模型的训练装置包括:
第一获取模块,用于获取多个训练样本;所述训练样本包括样本图像和与所述样本图像对应的标签集,所述标签集包括至少两个分类标签;
第二获取模块,用于从所述分类标签中获取标签集的语义特征;
第一分类模块,用于将所述样本图像和所述标签集的语义特征输入多标签图像分类模型,所述多标签图像分类模型基于视觉编码器从所述样本图像中获取视觉特征,并基于对比学习网络对所述视觉特征和所述语义特征进行处理得到所述样本图像和所述标签集的相关度,以及基于全连接网络对所述视觉特征进行处理得到所述样本图像映射到所述分类标签的预测结果;
构建模块,用于根据所述相关度构建相关度矩阵和与所述相关度矩阵对应的单位矩阵;
训练模块,用于基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失,并根据所述目标损失调整所述多标签图像分类模型的参数,直至满足收敛条件;
所述训练模块具体用于基于所述相关度矩阵和所述单位矩阵计算第一损失;基于所述预测结果和对应的分类标签计算第二损失;以及根据所述第一损失和所述第二损失确定所述目标损失;
所述第一损失包括行损失和列损失,所述训练模块具体用于将所述相关度矩阵和所述单位矩阵代入第一损失函数,得到所述行损失;所述第一损失函数包括其中Lx表示行损失,Mij表示相关度矩阵的第i行第j列,eij表示单位矩阵的第i行第j列,σ()表示激活函数,N表示样本图像的数量;将所述相关度矩阵和所述单位矩阵代入第二损失函数,得到所述列损失;所述第二损失函数包括其中Ly表示行损失,Mji表示相关度矩阵的第j行第i列,eji表示单位矩阵的第j行第i列,σ()表示激活函数,N表示样本图像的数量;以及根据所述行损失和所述列损失确定所述第一损失。
8.一种多标签图像分类模型的训练装置,其特征在于,所述多标签图像分类模型的训练装置包括:
第一获取模块,用于获取多个训练样本;所述训练样本包括样本图像和与所述样本图像对应的标签集,所述标签集包括至少两个分类标签;
第二获取模块,用于从所述分类标签中获取标签集的语义特征;
第一分类模块,用于将所述样本图像和所述标签集的语义特征输入多标签图像分类模型,所述多标签图像分类模型基于视觉编码器从所述样本图像中获取视觉特征,并基于对比学习网络对所述视觉特征和所述语义特征进行处理得到所述样本图像和所述标签集的相关度,以及基于全连接网络对所述视觉特征进行处理得到所述样本图像映射到所述分类标签的预测结果;
构建模块,用于根据所述相关度构建相关度矩阵和与所述相关度矩阵对应的单位矩阵;
训练模块,用于基于所述相关度矩阵、所述单位矩阵、所述预测结果和对应的分类标签计算目标损失,并根据所述目标损失调整所述多标签图像分类模型的参数,直至满足收敛条件;
所述训练模块具体用于基于所述相关度矩阵和所述单位矩阵计算第一损失;基于所述预测结果和对应的分类标签计算第二损失;以及根据所述第一损失和所述第二损失确定所述目标损失;
所述训练模块具体用于将所述预测结果和对应的分类标签代入第三损失函数,得到所述第二损失;所述第三损失函数包括 其中,Lv表示第二损失,gkt表示第k个样本图像映射到第t个分类标签的真实标签值,/>表示第k个样本图像映射到第t个分类标签的预测结果值,σ()表示激活函数,N表示样本图像的数量,T表示分类标签的数量。
9.一种多标签图像分类装置,其特征在于,所述多标签图像分类装置包括:
第三获取模块,用于获取待分类图像和至少两个候选标签;
第四获取模块,用于从所述候选标签中获取所述候选标签的语义特征;
第二分类模块,用于将所述待分类图像和所述候选标签的语义特征输入多标签图像分类模型进行分类处理,得到所述待分类图像和所述候选标签的相关度以及所述待分类图像映射到所述候选标签的预测结果;
其中,所述多标签图像分类模型通过如权利要求7或8所述的多标签图像分类模型的训练装置得到;
分类结果确定模块,用于根据所述相关度和所述预测结果确定所述待分类图像的分类结果。
10.一种电子设备,包括存储器、处理器及存储在存储器上并用于在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1-5中任一项所述的多标签图像分类模型的训练方法或者如权利要求6所述的多标签图像分类方法。
11.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1-5中任一项所述的多标签图像分类模型的训练方法或者如权利要求6所述的多标签图像分类方法。
CN202211626780.1A 2022-12-16 2022-12-16 多标签图像分类方法及其模型的训练方法、装置 Active CN115841596B (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202211626780.1A CN115841596B (zh) 2022-12-16 2022-12-16 多标签图像分类方法及其模型的训练方法、装置
PCT/CN2023/090036 WO2024124770A1 (zh) 2022-12-16 2023-04-23 多标签图像分类方法及其模型的训练方法、装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211626780.1A CN115841596B (zh) 2022-12-16 2022-12-16 多标签图像分类方法及其模型的训练方法、装置

Publications (2)

Publication Number Publication Date
CN115841596A CN115841596A (zh) 2023-03-24
CN115841596B true CN115841596B (zh) 2023-09-15

Family

ID=85578758

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211626780.1A Active CN115841596B (zh) 2022-12-16 2022-12-16 多标签图像分类方法及其模型的训练方法、装置

Country Status (2)

Country Link
CN (1) CN115841596B (zh)
WO (1) WO2024124770A1 (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115841596B (zh) * 2022-12-16 2023-09-15 华院计算技术(上海)股份有限公司 多标签图像分类方法及其模型的训练方法、装置

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109840531A (zh) * 2017-11-24 2019-06-04 华为技术有限公司 训练多标签分类模型的方法和装置
CN111898703A (zh) * 2020-08-14 2020-11-06 腾讯科技(深圳)有限公司 多标签视频分类方法、模型训练方法、装置及介质
CN113343941A (zh) * 2021-07-20 2021-09-03 中国人民大学 一种基于互信息相似度的零样本动作识别方法及系统
CN113449700A (zh) * 2021-08-30 2021-09-28 腾讯科技(深圳)有限公司 视频分类模型的训练、视频分类方法、装置、设备及介质
CN113657425A (zh) * 2021-06-28 2021-11-16 华南师范大学 基于多尺度与跨模态注意力机制的多标签图像分类方法
CN113723513A (zh) * 2021-08-31 2021-11-30 平安国际智慧城市科技股份有限公司 多标签图像分类方法、装置及相关设备
CN114241202A (zh) * 2021-12-17 2022-03-25 携程旅游信息技术(上海)有限公司 着装分类模型的训练方法及装置、着装分类方法及装置
CN114547249A (zh) * 2022-02-24 2022-05-27 济南融瓴科技发展有限公司 一种基于自然语言和视觉特征的车辆检索方法
CN114780719A (zh) * 2022-03-28 2022-07-22 京东城市(北京)数字科技有限公司 文本分类模型的训练方法、文本分类方法及装置

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110163234B (zh) * 2018-10-10 2023-04-18 腾讯科技(深圳)有限公司 一种模型训练方法、装置和存储介质
CN114299340A (zh) * 2021-12-30 2022-04-08 携程旅游信息技术(上海)有限公司 模型训练方法、图像分类方法、系统、设备及介质
CN114821298A (zh) * 2022-03-22 2022-07-29 大连理工大学 一种具有自适应语义信息的多标签遥感图像分类方法
CN115841596B (zh) * 2022-12-16 2023-09-15 华院计算技术(上海)股份有限公司 多标签图像分类方法及其模型的训练方法、装置

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109840531A (zh) * 2017-11-24 2019-06-04 华为技术有限公司 训练多标签分类模型的方法和装置
CN111898703A (zh) * 2020-08-14 2020-11-06 腾讯科技(深圳)有限公司 多标签视频分类方法、模型训练方法、装置及介质
CN113657425A (zh) * 2021-06-28 2021-11-16 华南师范大学 基于多尺度与跨模态注意力机制的多标签图像分类方法
CN113343941A (zh) * 2021-07-20 2021-09-03 中国人民大学 一种基于互信息相似度的零样本动作识别方法及系统
CN113449700A (zh) * 2021-08-30 2021-09-28 腾讯科技(深圳)有限公司 视频分类模型的训练、视频分类方法、装置、设备及介质
CN113723513A (zh) * 2021-08-31 2021-11-30 平安国际智慧城市科技股份有限公司 多标签图像分类方法、装置及相关设备
CN114241202A (zh) * 2021-12-17 2022-03-25 携程旅游信息技术(上海)有限公司 着装分类模型的训练方法及装置、着装分类方法及装置
CN114547249A (zh) * 2022-02-24 2022-05-27 济南融瓴科技发展有限公司 一种基于自然语言和视觉特征的车辆检索方法
CN114780719A (zh) * 2022-03-28 2022-07-22 京东城市(北京)数字科技有限公司 文本分类模型的训练方法、文本分类方法及装置

Also Published As

Publication number Publication date
WO2024124770A1 (zh) 2024-06-20
CN115841596A (zh) 2023-03-24

Similar Documents

Publication Publication Date Title
WO2021022521A1 (zh) 数据处理的方法、训练神经网络模型的方法及设备
CN110197195B (zh) 一种新型面向行为识别的深层网络系统及方法
CN112329760B (zh) 基于空间变换网络端到端印刷体蒙古文识别翻译的方法
KR20190104406A (ko) 처리방법 및 장치
US20230153615A1 (en) Neural network distillation method and apparatus
CN110222760B (zh) 一种基于winograd算法的快速图像处理方法
CN111898703B (zh) 多标签视频分类方法、模型训练方法、装置及介质
CN113128478B (zh) 模型训练方法、行人分析方法、装置、设备及存储介质
CN110245683B (zh) 一种少样本目标识别的残差关系网络构建方法及应用
CN113326851B (zh) 图像特征提取方法、装置、电子设备及存储介质
CN111046771A (zh) 用于恢复书写轨迹的网络模型的训练方法
CN114049515A (zh) 图像分类方法、系统、电子设备和存储介质
CN110781970A (zh) 分类器的生成方法、装置、设备及存储介质
CN115841596B (zh) 多标签图像分类方法及其模型的训练方法、装置
CN109754357B (zh) 图像处理方法、处理装置以及处理设备
CN110991247B (zh) 一种基于深度学习与nca融合的电子元器件识别方法
CN109492610A (zh) 一种行人重识别方法、装置及可读存储介质
Zhou et al. DPNet: Dual-path network for real-time object detection with lightweight attention
CN111368733A (zh) 一种基于标签分布学习的三维手部姿态估计方法、存储介质及终端
CN111444802A (zh) 一种人脸识别方法、装置及智能终端
CN114612681A (zh) 基于gcn的多标签图像分类方法、模型构建方法及装置
CN116797850A (zh) 基于知识蒸馏和一致性正则化的类增量图像分类方法
CN112801153B (zh) 一种嵌入lbp特征的图的半监督图像分类方法及系统
CN115049546A (zh) 样本数据处理方法、装置、电子设备及存储介质
CN114692715A (zh) 一种样本标注方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant