CN115205592A - 一种基于多模态数据的重平衡长尾图像数据分类方法 - Google Patents

一种基于多模态数据的重平衡长尾图像数据分类方法 Download PDF

Info

Publication number
CN115205592A
CN115205592A CN202210829253.4A CN202210829253A CN115205592A CN 115205592 A CN115205592 A CN 115205592A CN 202210829253 A CN202210829253 A CN 202210829253A CN 115205592 A CN115205592 A CN 115205592A
Authority
CN
China
Prior art keywords
data
image
text
class
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210829253.4A
Other languages
English (en)
Inventor
陈东明
赵雨萌
赵文吕
聂铭硕
王冬琦
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Northeastern University China
Original Assignee
Northeastern University China
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Northeastern University China filed Critical Northeastern University China
Priority to CN202210829253.4A priority Critical patent/CN115205592A/zh
Publication of CN115205592A publication Critical patent/CN115205592A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • G06V10/765Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/74Image or video pattern matching; Proximity measures in feature spaces
    • G06V10/75Organisation of the matching processes, e.g. simultaneous or sequential comparisons of image or video features; Coarse-fine approaches, e.g. multi-scale approaches; using context analysis; Selection of dictionaries
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/74Image or video pattern matching; Proximity measures in feature spaces
    • G06V10/761Proximity, similarity or dissimilarity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Databases & Information Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明属于图像分类领域,设计了一种基于多模态数据的重平衡长尾图像数据分类方法。该方法实现图像‑文本多模态学习在长尾图像分类问题上的应用,旨在使用相对平衡且易于获得、扩展性丰富的文本数据来监督模型对图像特征的学习,通过两个阶段的训练,提高模型在所有种类上的分类效果。第一个阶段使用CLIP大规模预训练模型中的图像和文本编码器,通过对比学习的方法建立两个模态数据的关联性,增强类内图像与文本互信息的同时扩大类间差异性。第二个阶段冻结图像与文本编码器,并在图像编码器后增加了一个多层感知机,使用类平衡采样策略和重平衡损失函数训练少量周期,进一步改善模型对于尾部类的分类能力。

Description

一种基于多模态数据的重平衡长尾图像数据分类方法
技术领域
本发明属于图像分类领域,具体涉及一种基于多模态数据的重平衡长尾图像数据分类方法。
背景技术
图像分类问题是计算机视觉领域的基础问题,旨在根据图像的语义信息将不同类别图像区分开来,实现最小的分类误差。深度学习在图像分类任务上表现良好很大程度上要归功于大规模的高质量训练数据,其中,不同种类包含的样本数量相同,避免了训练样本不均衡带来的负面影响。然而在现实应用中采集到的数据通常呈现为长尾分布,模型难以学习到所有种类的良好特征表示。
长尾数据不均衡给分类带来的问题本质上是最终分类器权重的不均衡问题,目前主流利用信息增强解决,该方法旨在模型训练的过程中引入额外的信息进行辅助,从而提升模型性能。头尾知识迁移利用头部类中的类内方差知识来指导尾部类进行特征增强,使尾部类的特征具备更大的类内方差;模型预训练利用对比学习的方式先进行自监督学习完成预训练,之后再对长尾数据进行正常的训练;知识蒸馏通过一个训练有素的老师模型的输出,去指导训练学生模型;自监督训练使用标记样本训练一个有监督模型,然后利用该模型为未标记样本生成伪标签,最后使用标记样本和未标记样本再次训练模型。
得益于额外引入的知识辅助,这些或迁移或增强数据的方法在提高尾部类分类效果的前提下,并未对头部类造成额外的负面影响,从根源上解决了长尾数据缺乏足够的尾部类信息的问题,是一种值得深入探索的方向。然而,简单使用数据增强技术往往不能有效的划分头部类与尾部类,头部类拥有更多的样本,会得到更多的增强处理,从而进一步加强了信息的不均衡现象。
发明内容
针对现有技术的不足,本发明设计一种基于多模态数据的重平衡长尾图像数据分类方法。
一种基于多模态数据的重平衡长尾图像数据分类方法,具体步骤如下:
步骤1:对图像数据和文本数据进行预处理;
给定一个mini-batch的图像数据I={I1,...,IN}和对应的标签文本数据T={T1,...,TN},其中N为batch size;将mini-batch中属于i类的图像和标签文本两种模态数据表示为
Figure BDA0003747513170000011
Figure BDA0003747513170000012
其中
Figure BDA0003747513170000013
Figure BDA0003747513170000014
为I和T的子集,大小为n;
步骤2:对步骤1得到的图像模态数据
Figure BDA0003747513170000021
和标签文本模态数据
Figure BDA0003747513170000022
进行降维编码;
对于任意
Figure BDA0003747513170000023
将标签文本模态数据套入prompt模板“a photo of a{class}”变成句子并计算token;
Figure BDA0003747513170000024
中的图像与句子token分别送入图像编码器EI和文本编码器ET进行计算,得到图像模态嵌入表示
Figure BDA0003747513170000025
和标签描述模态嵌入表示
Figure BDA0003747513170000026
Figure BDA0003747513170000027
其中
Figure BDA0003747513170000028
D为两个模态Embedding对齐后的输出维度;
步骤3:计算图像模态数据和标签文本模态数据的相似度匹配;
根据步骤2得到的两个模态嵌入表示,通过余弦相似度S判断图像模态数据和标签文本模态数据是否匹配;
Figure BDA0003747513170000029
其中,
Figure BDA00037475131700000210
为属于j类图像模态嵌入表示,
Figure BDA00037475131700000211
为属于k类标签描述模态嵌入表示;
步骤4:对图像模态嵌入表示和标签描述模态嵌入表示进行对比学习预训练,来建立图像模态数据与标签文本模态数据之间种类内部的关联性,同时扩大类间相似性边界,作为第一阶段,即预训练CLIP模型;
步骤4.1:一个mini-batch中正样本个数为n2,为所有同类图像模态数据与标签文本模态数据之间的相似度,负样本个数为N2-n2,为i类图像模态数据与其他不同种类标签文本模态数据之间的相似度,mini-batch的余弦相似度矩阵
Figure BDA00037475131700000212
步骤4.2:对于任一图像模态嵌入表示和标签描述模态嵌入表示,将mini-batch中与其种类相同的对应模态数据所处位置下标编码为1,对不同种类的其所处位置下标设为0,得到一个mini-batch的两种模态数据编码矩阵
Figure BDA00037475131700000213
步骤4.3:计算第一阶段对比学习的损失函数:
Figure BDA00037475131700000214
其中τ为对比学习中的温度系数,设置初始值为0~0.1,并随着训练过程而更新;Si,j是属于i类图像模态嵌入表示和属于j类标签描述模态嵌入表示的余弦相似度;Si,k是属于i类图像模态嵌入表示和属于k类标签描述模态嵌入表示的余弦相似度;
步骤4.4:对CLIP预训练模型中优秀的特征提取能力进行知识蒸馏,使用一个蒸馏损失函数辅助完成知识迁移,以避免训练过程中对CLIP预训练模型造成过拟合现象:
Figure BDA0003747513170000031
其中S′为原始CLIP预训练模型冻结后对相同数据计算而得的余弦相似度;
步骤5:计算最终第一阶段的损失:
Figure BDA0003747513170000032
其中α为超参数,用于调节原始CLIP模型知识蒸馏占模型预训练的比重;
步骤6:重复执行步骤2-5,利用梯度下降算法进行反向传播,更新图像编码器参数,实现第一阶段CLIP模型预训练;
步骤7:任意给定一个大小为N的mini-batch图像模态数据
Figure BDA0003747513170000033
种类数量为C的所有种类标签文本模态数据的句子token为
Figure BDA0003747513170000034
分别经过图像编码器和文本编码器计算后得到嵌入表示
Figure BDA0003747513170000035
Figure BDA0003747513170000036
步骤8:计算步骤7得到的图像嵌入表示fI和标签描述嵌入表示fT的原始余弦相似度:
Sori=fI⊙(fT)·
得到
Figure BDA0003747513170000037
表示第一阶段训练后CLIP模型基于fI和fT,对每个图像种类的预测值;
步骤9:因为图像模态数据呈长尾分布,所以使用图像模态数据和文本模态数据对进行匹配分类仍然不能摆脱失衡问题,因此对CLIP模型进行重平衡以改变图像嵌入表示fI对标签描述嵌入表示fT的敏感程度,作为第二阶段,具体步骤如下:
步骤9.1:fI经过MLP映射后维度不变,再与fT计算相似度:
Smlp=MLP(fI)⊙(fT)·
步骤9.2:将
Figure BDA0003747513170000038
中的余弦相似度加上种类数量权重,得到平衡的余弦相似度:
Figure BDA0003747513170000039
其中,i∈[1,N],
Figure BDA00037475131700000310
μj=nj/n表示第j类样本数目占总数的比例;
步骤9.3:训将
Figure BDA00037475131700000311
与模态数据的one-hot标签使用交叉熵损失函数计算损失,之后进行反向传播,更新MLP参数:
Figure BDA00037475131700000312
其中τ为第一阶段训练后冻结的温度系数;
步骤10:将
Figure BDA0003747513170000041
和Sori加权求和,作为最终预测输出:
Sfinal=λ*Sori+(1-λ)*Sbal
其中λ为超参数,用于调整MLP模块重平衡的权重;
Figure BDA0003747513170000042
代表计算后得到的该图像模态数据对于所有种类描述文本模态数据的匹配程度,即代表预测结果,故argmax(Sfinal)为最终预测种类。
本发明有益技术效果
一种基于多模态数据的重平衡长尾图像数据分类方法,实现图像-文本多模态学习在长尾图像分类问题上的应用。在图像分类中长尾分布的训练数据会导致模型的学习过程容易被样本数据丰富的头部类别主导,对尾部类别的学习建模能力有限,给最终的分类准确率带来挑战。而本发明将模型的特征学习过程与针对长尾问题的重平衡过程解耦作为两个阶段来学习。第一个阶段保持数据集的原始采样策略不变,充分利用所有数据进行图像编码器特征学习,并引入文字模态的特征表示提供监督信息。为进一步改善样本数量稀少的种类的分类性能,消除长尾数据对编码器训练带来的学习偏差,在第二个阶段,冻结图像与文本编码器,并在图像编码器后增加一个多层感知机(Multilayer Perceptron,MLP)用来重新平衡图像编码器。为保留第一阶段学习到的良好特征表示,采用残差连接的思想,结合MLP添加前后模型输出的图像-文本对相似度作为最终的预测值。
本发明不但能够学习到良好的图像特征,而且可以利用图像标签的文字信息辅助引导模型分类,采用对比学习的方式将分类预测问题转换为图像-文本的配对问题,从而实现通过文本数据来监督图像分类训练。
与现有技术相比,本发明提出的技术方案相对其而言计算量大幅减少,该模型对于长尾分布数据集的特征学习能力已经十分接近于均衡数据集。该方法对于种类间的图像-文本对之间的分类边界学习的相当充分,能够从有限的样本数据中学习到类别间的差异性,也体现了标签文本描述的监督对于图像特征提取学习的促进作用。
附图说明
图1本发明CLIP模型第一阶段对比学习预训练框架示意图;
图2本发明CLIP模型第二阶段模型重平衡框架示意图。
具体实施方式
下面结合附图和实施例对本发明做进一步说明;
本发明从图像-文本共同训练的角度出发,探索利用文字信息填补尾部类图像数量稀少带来的信息匮乏缺陷。首先使用对比学习预训练将各类图像与标签文字描述建立关联,最大化类内图像-文本对互信息的同时,扩大种类间的差异。之后针对图像长尾分布的特点引入种类间的图像样本数量信息,增加了一个多层感知机再次训练重平衡图像编码器。
一种基于多模态数据的重平衡长尾图像数据分类方法,具体步骤如下:
步骤1:对图像数据和文本数据进行预处理;
给定一个mini-batch的图像数据I={I1,...,IN}和对应的标签文本数据T={T1,...,TN},其中N为batch size;将mini-batch中属于i类的图像和标签文本两种模态数据表示为
Figure BDA0003747513170000051
Figure BDA0003747513170000052
其中
Figure BDA0003747513170000053
Figure BDA0003747513170000054
为I和T的子集,大小为n;
步骤2:对步骤1得到的图像模态数据
Figure BDA0003747513170000055
和标签文本模态数据
Figure BDA0003747513170000056
进行降维编码;
对于任意
Figure BDA0003747513170000057
将标签文本模态数据套入prompt模板“a photo of a{class}”变成句子并计算token;
Figure BDA0003747513170000058
和的图像与句子token分别送入图像编码器EI和文本编码器ET进行计算,得到图像模态嵌入表示
Figure BDA0003747513170000059
和标签描述模态嵌入表示
Figure BDA00037475131700000510
Figure BDA00037475131700000511
其中
Figure BDA00037475131700000512
D为两个模态Embedding对齐后的输出维度;
步骤3:计算图像模态数据和标签文本模态数据的相似度匹配;
根据步骤2得到的两个模态嵌入表示,通过余弦相似度S判断图像模态数据和标签文本模态数据是否匹配;
Figure BDA00037475131700000513
其中,
Figure BDA00037475131700000514
为属于j类图像模态嵌入表示,
Figure BDA00037475131700000515
为属于k类标签描述模态嵌入表示;
步骤4:对图像模态嵌入表示和标签描述模态嵌入表示进行对比学习预训练,来建立图像模态数据与标签文本模态数据之间种类内部的关联性,同时扩大类间相似性边界,作为第一阶段,即预训练CLIP模型;如附图1所示;
步骤4.1:一个mini-batch中正样本个数为n2,为所有同类图像模态数据与标签文本模态数据之间的相似度,负样本个数为N2-n2,为i类图像模态数据与其他不同种类标签文本模态数据之间的相似度,mini-batch的余弦相似度矩阵
Figure BDA00037475131700000516
步骤4.2:对于任一图像模态嵌入表示和标签描述模态嵌入表示,将mini-batch中与其种类相同的对应模态数据所处位置下标编码为1,对不同种类的其所处位置下标设为0,得到一个mini-batch的两种模态数据编码矩阵
Figure BDA0003747513170000061
步骤4.3:计算第一阶段对比学习的损失函数:
Figure BDA0003747513170000062
其中τ为对比学习中的温度系数,设置初始值为0~0.1,并随着训练过程而更新;Si,j是属于i类图像模态嵌入表示和属于j类标签描述模态嵌入表示的余弦相似度;Si,k是属于i类图像模态嵌入表示和属于k类标签描述模态嵌入表示的余弦相似度;
步骤4.4:对CLIP预训练模型中优秀的特征提取能力进行知识蒸馏,使用一个蒸馏损失函数辅助完成知识迁移,以避免训练过程中对CLIP预训练模型造成过拟合现象:
Figure BDA0003747513170000063
其中S′为原始CLIP预训练模型冻结后对相同数据计算而得的余弦相似度;
步骤5:计算最终第一阶段的损失:
Figure BDA0003747513170000064
其中α为超参数,用于调节原始CLIP模型知识蒸馏占模型预训练的比重;
步骤6:重复执行步骤2-5,利用梯度下降算法进行反向传播,更新图像编码器参数,实现第一阶段CLIP模型预训练;
步骤7:任意给定一个大小为N的mini-batch图像模态数据
Figure BDA0003747513170000065
种类数量为C的所有种类标签文本模态数据的句子token为
Figure BDA0003747513170000066
分别经过图像编码器和文本编码器计算后得到嵌入表示
Figure BDA0003747513170000067
Figure BDA0003747513170000068
步骤8:计算步骤7得到的图像嵌入表示fI和标签描述嵌入表示fT的原始余弦相似度:
Sori=fI⊙(fT)·
得到
Figure BDA0003747513170000069
表示第一阶段训练后CLIP模型基于fI和fT,对每个图像种类的预测值;
步骤9:因为图像模态数据呈长尾分布,所以使用图像模态数据和文本模态数据对进行匹配分类仍然不能摆脱失衡问题,因此对CLIP模型进行重平衡以改变图像嵌入表示fI对标签描述嵌入表示fT的敏感程度,作为第二阶段,如附图2所示,具体步骤如下:
步骤9.1:fI经过MLP映射后维度不变,再与fT计算相似度:
Smlp=MLP(fI)⊙(fT)·
步骤9.2:将
Figure BDA0003747513170000071
中的余弦相似度加上种类数量权重,得到平衡的余弦相似度:
Figure BDA0003747513170000072
其中,i∈[1,N],
Figure BDA0003747513170000073
μj=nj/n表示第j类样本数目占总数的比例;
步骤9.3:训将
Figure BDA0003747513170000074
与模态数据的one-hot标签使用交叉熵损失函数计算损失,之后进行反向传播,更新MLP参数:
Figure BDA0003747513170000075
其中τ为第一阶段训练后冻结的温度系数;
步骤10:将
Figure BDA0003747513170000076
和Sori加权求和,作为最终预测输出:
Sfinal=λ*Sori+(1-λ)*Sbal
其中λ为超参数,用于调整MLP模块重平衡的权重;
Figure BDA0003747513170000077
代表计算后得到的该图像模态数据对于所有种类描述文本模态数据的匹配程度,即代表预测结果,故argmax(Sfinal)为最终预测种类。
本发明使用长尾分类领域CIFAR100数据集、ImageNet2012数据集和Places365数据集。由于原始数据集分布均衡,故采取通用处理方法将其划分为长尾分布数据集。本发明使用Top-1准确率作为主要评价指标,即最终分类器输出向量中概率最大的值所在下标作为模型预测类别的准确率。
Figure BDA0003747513170000078
其中At即Top-1准确率,Au为经验性参考准确率,是基线准确率Av和平衡准确率Ab中的最大值。基线准确率Av表示算法使用的骨干网络在均衡训练集上使用交叉熵损失函数训练后的测试集准确率,平衡准确率Ab表示长尾方法在均衡训练集上训练后的测试集准确率。
本发明的基础网络为CLIP预训练模型,分别使用其公布的ResNet-50和ViT-Base/16两个Encoder作为图像编码器,文本编码器则为CLIP预训练后的GPT-2中的Transformer。
表1所示为CIFAR100-LT数据集上不同算法在CIFAR100-LT数据集3个不平衡比例ρ下的Top-1准确率,对比算法包括了传统类别重平衡(Focal Loss、LDAM等)、信息增强(OLTR、MiSLAS等)和改善模型模块(BBN、RIDE等)等方法、最新的有关图像-文本多模态学习在长尾分类上应用的算法如BALLAD。
表1 CIFAR100-LT不同ρ下的Top-1%准确率;
Figure BDA0003747513170000079
Figure BDA0003747513170000081
可以看到在3个不平衡比例数据集下,该方法均取得了最优效果,相比于同样使用CLIP预训练模型的BALLAD算法,ρ=10、50和100时,ResNet-50图像编码器的Top-1准确率分别提高了6.6%,5.6%和4.8%。而ViT-16图像编码器由于特征学习己相对优秀,故提升较小,Top-1准确率分别提高了1.6%,0.3%和0.5%。需要注意到的是,BALLAD第一阶段没有冻结文本编码器,仍需学习更新文本编码器参数,因此本发明相对其而言计算量大幅减少。
表2所示为ImageNet-LT数据集上各算法的相对精度指标Ar的实验结果。
表2 ImageNet-LT相对精度(%)对比结果;
Figure BDA0003747513170000082
可以看到,BMLTC的相对精度高于90%,说明该模型对于长尾分布数据集的特征学习能力已经十分接近于均衡数据集。但BALLAD和BMLTC模型的相对精度不如RIDE,证明这类迁移算法之所以在长尾分布数据集上分类性能优越,部分原因来源于使用了图像特征提取能力强大的CLIP预训练模型。而RIDE的Au相较于其他方法更高,说明模型取得良好分类效果的原因不仅限于对分类器的调整,虽然也提升了整体的学习泛化能力,但对于长尾分布数据集的提升更为明显,故最终相对精度也更高。此外,信息增强方法的Au相对较高,也说明信息增强方法大都可以提高模型的特征学习能力,而设计特殊损失函数的方法对于长尾分布数据的训练更加贴合。
表3 Places-LT各部分的Top-1准确率(%);
Figure BDA0003747513170000091
表3所示为Places-LT数据集上各算法的对比结果,除了BALLAD和BMLTC外,其他方法均使用ResNet-152作为骨干网络。可以看到,由于这些算法都是基于预训练好的ResNet-152模型再训练的,故在Places-LT这个数据集上各个算法的分类效果差距不是很大。BMLTC在ResNet-50和ViT-16图像编码器中分别相比BALLAD准确率均提升了0.5%和0.6%,在Few-Shot上与BALLAD提升不大甚至略低,而在Many-Shot上提升较多,分别提升了1.2%和1.4%。

Claims (8)

1.一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,具体步骤如下:
步骤1:对图像数据和文本数据进行预处理;
给定一个mini-batch的图像数据I={I1,...,IN}和对应的标签文本数据T={T1,...,TN},其中N为batch size;将mini-batch中属于i类的图像和标签文本两种模态数据表示为
Figure FDA0003747513160000011
和Ti +,其中
Figure FDA0003747513160000012
和Ti +为I和T的子集,大小为n;
步骤2:对步骤1得到的图像模态数据
Figure FDA0003747513160000013
和标签文本模态数据Ti +进行降维编码;
步骤3:计算图像模态数据和标签文本模态数据的相似度匹配;
步骤4:对图像模态嵌入表示和标签描述模态嵌入表示进行对比学习预训练,来建立图像模态数据与标签文本模态数据之间种类内部的关联性,同时扩大类间相似性边界,作为第一阶段,即预训练CLIP模型;
步骤5:计算最终第一阶段的损失;
步骤6:重复执行步骤2-5,利用梯度下降算法进行反向传播,更新图像编码器参数,实现第一阶段CLIP模型预训练;
步骤7:任意给定一个大小为N的mini-batch图像模态数据
Figure FDA0003747513160000014
种类数量为C的所有种类标签文本模态数据的句子token为
Figure FDA0003747513160000015
分别经过图像编码器和文本编码器计算后得到嵌入表示
Figure FDA0003747513160000016
Figure FDA0003747513160000017
步骤8:计算步骤7得到的图像嵌入表示fI和标签描述嵌入表示fT的原始余弦相似度Sori
步骤9:因为图像模态数据呈长尾分布,所以使用图像模态数据和文本模态数据对进行匹配分类仍然不能摆脱失衡问题,因此对CLIP模型进行重平衡以改变图像嵌入表示fI对标签描述嵌入表示fT的敏感程度,作为第二阶段,得到
Figure FDA0003747513160000018
步骤10:将
Figure FDA0003747513160000019
和Sori加权求和,作为最终预测输出。
2.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,步骤2具体为:
对于任意
Figure FDA00037475131600000110
Tj∈Ti +,将标签文本模态数据套入prompt模板“a photo of a{class}”变成句子并计算token;
Figure FDA00037475131600000111
中的图像与句子token分别送入图像编码器EI和文本编码器ET进行计算,得到图像模态嵌入表示fj I和标签描述模态嵌入表示fj T
Figure FDA00037475131600000112
其中
Figure FDA0003747513160000021
D为两个模态Embedding对齐后的输出维度。
3.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,步骤3具体为:
根据步骤2得到的两个模态嵌入表示,通过余弦相似度S判断图像模态数据和标签文本模态数据是否匹配;
Figure FDA0003747513160000022
其中,fj I为属于j类图像模态嵌入表示,
Figure FDA0003747513160000023
为属于k类标签描述模态嵌入表示。
4.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,步骤4具体为:
步骤4.1:一个mini-batch中正样本个数为n2,为所有同类图像模态数据与标签文本模态数据之间的相似度,负样本个数为N2-n2,为i类图像模态数据与其他不同种类标签文本模态数据之间的相似度,mini-batch的余弦相似度矩阵
Figure FDA0003747513160000024
步骤4.2:对于任一图像模态嵌入表示和标签描述模态嵌入表示,将mini-batch中与其种类相同的对应模态数据所处位置下标编码为1,对不同种类的其所处位置下标设为0,得到一个mini-batch的两种模态数据编码矩阵
Figure FDA0003747513160000025
步骤4.3:计算第一阶段对比学习的损失函数:
Figure FDA0003747513160000026
其中τ为对比学习中的温度系数,设置初始值为0~0.1,并随着训练过程而更新;Si,j是属于i类图像模态嵌入表示和属于j类标签描述模态嵌入表示的余弦相似度;Si,k是属于i类图像模态嵌入表示和属于k类标签描述模态嵌入表示的余弦相似度;
步骤4.4:对CLIP预训练模型中优秀的特征提取能力进行知识蒸馏,使用一个蒸馏损失函数辅助完成知识迁移,以避免训练过程中对CLIP预训练模型造成过拟合现象:
Figure FDA0003747513160000027
其中S′为原始CLIP预训练模型冻结后对相同数据计算而得的余弦相似度。
5.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,步骤5计算第一阶段最终损失具体为:
Figure FDA0003747513160000031
其中α为超参数,用于调节原始CLIP模型知识蒸馏占模型预训练的比重。
6.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,步骤8原始余弦相似度具体为:
Sori=fI⊙(fT)·
得到
Figure FDA0003747513160000032
表示第一阶段训练后CLIP模型基于fI和fT,对每个图像种类的预测值。
7.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,步骤9具体为:
步骤9.1:fI经过MLP映射后维度不变,再与fT计算相似度:
Smlp=MLP(fI)⊙(fT)·
步骤9.2:将
Figure FDA0003747513160000033
中的余弦相似度加上种类数量权重,得到平衡的余弦相似度:
Figure FDA0003747513160000034
其中,i∈[1,N],
Figure FDA0003747513160000035
μj=nj/n表示第j类样本数目占总数的比例;
步骤9.3:训将
Figure FDA0003747513160000036
与模态数据的one-hot标签使用交叉熵损失函数计算损失,之后进行反向传播,更新MLP参数:
Figure FDA0003747513160000037
其中τ为第一阶段训练后冻结的温度系数。
8.根据权利要求1所述的一种基于多模态数据的重平衡长尾图像数据分类方法,其特征在于,最终预测输出具体为:
Sfinal=λ*Sori+(1-λ)*Sbal
其中λ为超参数,用于调整MLP模块重平衡的权重;
Figure FDA0003747513160000038
代表计算后得到的该图像模态数据对于所有种类描述文本模态数据的匹配程度,即代表预测结果,故argmax(Sfinal)为最终预测种类。
CN202210829253.4A 2022-07-15 2022-07-15 一种基于多模态数据的重平衡长尾图像数据分类方法 Pending CN115205592A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210829253.4A CN115205592A (zh) 2022-07-15 2022-07-15 一种基于多模态数据的重平衡长尾图像数据分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210829253.4A CN115205592A (zh) 2022-07-15 2022-07-15 一种基于多模态数据的重平衡长尾图像数据分类方法

Publications (1)

Publication Number Publication Date
CN115205592A true CN115205592A (zh) 2022-10-18

Family

ID=83581993

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210829253.4A Pending CN115205592A (zh) 2022-07-15 2022-07-15 一种基于多模态数据的重平衡长尾图像数据分类方法

Country Status (1)

Country Link
CN (1) CN115205592A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115410059A (zh) * 2022-11-01 2022-11-29 山东锋士信息技术有限公司 基于对比损失的遥感图像部分监督变化检测方法及设备
CN115829058A (zh) * 2022-12-23 2023-03-21 北京百度网讯科技有限公司 训练样本处理方法、跨模态匹配方法、装置、设备和介质
CN115830006A (zh) * 2023-02-03 2023-03-21 山东锋士信息技术有限公司 一种基于近邻对比的改进超球空间学习的异常检测方法
CN115908949A (zh) * 2023-01-06 2023-04-04 南京理工大学 基于类平衡编码器的长尾图像识别方法
KR102622435B1 (ko) * 2023-04-11 2024-01-08 고려대학교산학협력단 텍스트를 활용한 도메인 비특이적인 이미지 분류 장치 및 방법

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115410059A (zh) * 2022-11-01 2022-11-29 山东锋士信息技术有限公司 基于对比损失的遥感图像部分监督变化检测方法及设备
CN115410059B (zh) * 2022-11-01 2023-03-24 山东锋士信息技术有限公司 基于对比损失的遥感图像部分监督变化检测方法及设备
CN115829058A (zh) * 2022-12-23 2023-03-21 北京百度网讯科技有限公司 训练样本处理方法、跨模态匹配方法、装置、设备和介质
CN115829058B (zh) * 2022-12-23 2024-04-23 北京百度网讯科技有限公司 训练样本处理方法、跨模态匹配方法、装置、设备和介质
CN115908949A (zh) * 2023-01-06 2023-04-04 南京理工大学 基于类平衡编码器的长尾图像识别方法
CN115908949B (zh) * 2023-01-06 2023-11-17 南京理工大学 基于类平衡编码器的长尾图像识别方法
CN115830006A (zh) * 2023-02-03 2023-03-21 山东锋士信息技术有限公司 一种基于近邻对比的改进超球空间学习的异常检测方法
KR102622435B1 (ko) * 2023-04-11 2024-01-08 고려대학교산학협력단 텍스트를 활용한 도메인 비특이적인 이미지 분류 장치 및 방법

Similar Documents

Publication Publication Date Title
CN110298037B (zh) 基于增强注意力机制的卷积神经网络匹配的文本识别方法
CN115205592A (zh) 一种基于多模态数据的重平衡长尾图像数据分类方法
CN110490239B (zh) 图像质控网络的训练方法、质量分类方法、装置及设备
CN102314614B (zh) 一种基于类共享多核学习的图像语义分类方法
CN110490242B (zh) 图像分类网络的训练方法、眼底图像分类方法及相关设备
CN112256866B (zh) 一种基于深度学习的文本细粒度情感分析算法
CN109948696A (zh) 一种多语言场景字符识别方法及系统
CN113673254A (zh) 基于相似度保持的知识蒸馏的立场检测方法
CN112434686B (zh) 针对ocr图片的端到端含错文本分类识别仪
CN113657115A (zh) 一种基于讽刺识别和细粒度特征融合的多模态蒙古文情感分析方法
CN112651940A (zh) 基于双编码器生成式对抗网络的协同视觉显著性检测方法
CN116579345B (zh) 命名实体识别模型的训练方法、命名实体识别方法及装置
CN112883931A (zh) 基于长短期记忆网络的实时真假运动判断方法
CN111930981A (zh) 一种草图检索的数据处理方法
CN116342942A (zh) 基于多级域适应弱监督学习的跨域目标检测方法
CN116579347A (zh) 一种基于动态语义特征融合的评论文本情感分析方法、系统、设备及介质
CN116246279A (zh) 一种基于clip背景知识的图文特征融合方法
CN116049367A (zh) 一种基于无监督知识增强的视觉-语言预训练方法及装置
CN114662456A (zh) 基于Faster R-卷积神经网络检测模型的图像古诗生成方法
CN113920379A (zh) 一种基于知识辅助的零样本图像分类方法
CN114462466A (zh) 一种面向深度学习的数据去偏方法
CN113722439A (zh) 基于对抗性类别对齐网络的跨领域情感分类方法及系统
CN116662924A (zh) 基于双通道与注意力机制的方面级多模态情感分析方法
CN113139464B (zh) 一种电网故障检测方法
CN115577072A (zh) 一种基于深度学习的短文本情感分析方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination