CN116524242A - 基于监督对比学习和多任务设置的小样本图像分类优化方法 - Google Patents
基于监督对比学习和多任务设置的小样本图像分类优化方法 Download PDFInfo
- Publication number
- CN116524242A CN116524242A CN202310366976.XA CN202310366976A CN116524242A CN 116524242 A CN116524242 A CN 116524242A CN 202310366976 A CN202310366976 A CN 202310366976A CN 116524242 A CN116524242 A CN 116524242A
- Authority
- CN
- China
- Prior art keywords
- image
- task
- sample
- learning
- small sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 60
- 238000005457 optimization Methods 0.000 title claims abstract description 16
- 238000012549 training Methods 0.000 claims abstract description 41
- 238000012512 characterization method Methods 0.000 claims abstract description 4
- 239000000523 sample Substances 0.000 claims description 101
- 239000013598 vector Substances 0.000 claims description 36
- 239000011159 matrix material Substances 0.000 claims description 32
- 230000006870 function Effects 0.000 claims description 15
- 230000008569 process Effects 0.000 claims description 11
- 239000013074 reference sample Substances 0.000 claims description 9
- 238000013507 mapping Methods 0.000 claims description 8
- 238000004364 calculation method Methods 0.000 claims description 7
- 230000014509 gene expression Effects 0.000 claims description 7
- 238000013528 artificial neural network Methods 0.000 claims description 6
- 239000003795 chemical substances by application Substances 0.000 claims description 4
- 238000010606 normalization Methods 0.000 claims description 4
- 238000001514 detection method Methods 0.000 claims description 3
- 238000007781 pre-processing Methods 0.000 claims description 3
- 238000005070 sampling Methods 0.000 claims description 3
- 238000011161 development Methods 0.000 abstract description 5
- 230000000694 effects Effects 0.000 abstract description 5
- 230000008901 benefit Effects 0.000 abstract description 4
- 230000007547 defect Effects 0.000 abstract description 4
- 238000013473 artificial intelligence Methods 0.000 description 5
- 238000011160 research Methods 0.000 description 5
- 238000013459 approach Methods 0.000 description 4
- 230000018109 developmental process Effects 0.000 description 4
- 238000010586 diagram Methods 0.000 description 4
- 230000009466 transformation Effects 0.000 description 4
- JJWKPURADFRFRB-UHFFFAOYSA-N carbonyl sulfide Chemical compound O=C=S JJWKPURADFRFRB-UHFFFAOYSA-N 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 238000002474 experimental method Methods 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000012546 transfer Methods 0.000 description 3
- 101100455978 Arabidopsis thaliana MAM1 gene Proteins 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 238000005304 joining Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000004519 manufacturing process Methods 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 230000011218 segmentation Effects 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 210000004556 brain Anatomy 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000007405 data analysis Methods 0.000 description 1
- 230000002950 deficient Effects 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 238000007689 inspection Methods 0.000 description 1
- 238000000691 measurement method Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000008520 organization Effects 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000012706 support-vector machine Methods 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/762—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using clustering, e.g. of similar faces in social networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- General Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于监督对比学习和多任务设置的小样本图像分类优化方法,该方法具体包括:特征提取器训练、代理任务和小样本分类等阶段,特征提取器训练阶将类别监督信号引入对比学习参与图像表征训练;代理任务阶替换为更加适应小样本学习的背景混淆方法;小样本分类阶段在余弦分类器的基础上加入了多任务设置,缓解了元学习训练带来的参数优化不稳定的弊端。本发明与现有技术相比具有引导特征提取器学习更多语义信息,增强元知识的一般性,提高了小样本学习在一般的图像分类任务的准确度,方法简便,使用效果好,有较高的实用价值与良好的发展前景。
Description
技术领域
本发明涉及小样本图像分类技术领域,尤其是一种基于监督对比学习和多任务的图像分类优化方法。
背景技术
数据驱动引导的人工智能取得了非常瞩目的成就,几乎涵盖了整个计算机视觉、自然语言处理、音频和语音、强化学习和数据分析等多个热点研究领域。然而,很多场景下由于诸如隐私、安全和成本等原因,往往无法获得大规模标注过的训练样本。当只有有限的样本时,常规的模型很难在不同的数据分布中进行推广。以智能制造场景下的工业检测为例,模型泛化性差已成为影响其生产效率的关键挑战,特别是在照明条件的经常发生变化和缺陷样本稀缺的情况下。类似的问题也存在一些高精尖的行业。例如在寻找高能宇宙射线的任务中,研究人员希望从密密麻麻的宇宙射线中找到携带特殊能量的宇宙射线。整个过程将会非常耗时,且成本高昂。如今,小样本学习在生活的各行各业已经非常活跃,包括信用卡欺诈、票据识别、意图识别、冷启动推荐和手势识别。
研究小样本学习对人工智能的发展有着非常重要的现实意义。2003年,小样本学习的概念首次被李飞飞提出就引发了业界广泛讨论。它颠覆了传统机器学习对于数据和知识的主观认知。全新的类别,极其有限的数据集成了小样本学习最大的特点。国外对于小样本学习的研究起步较早,且成果较为丰富。2017年,多伦多大学联合著名的美国社交通讯巨头Twitter发表了一篇奠基性的研究成果。小样本学习并不需要对模型进行特殊设计,特征提取器和一个基于度量学习的分类器即可达到不错的结果。这种简单的思想影响一大批小样本学习工作在此基础上取得积极的进展。同年,OpenAI提出了元学习范式,试图在任务层面帮助小样本学习取得突破。元学习的提出,使得小样本学习在预训练阶段不再需要大量的基类数据。随后,在2018年中国举办的首届世界人工智能大会上,中国科学院张钹做了主题为《基于大数据的人工智能》的报告,报告指出:“深度学习对于数据的利用率太低,导致我们需要大量数据去训练它,我们要想彻底解决小样本学习,必须增大对单个数据的利用率”。同样的,Hugo Larochelle作为谷歌大脑首席科学家在2019年深度学习峰会上指出,“小样本学习其独特的学习方式,必然成为深度学习领域的下一个热点”。解决小样本学习是机器学习通往通用人工智能的必要途径。小样本学习的概念受到人类强大的推理和分析能力的启发,旨在从少量样本中学习和概括全新概念。
目前,小样本学习的基线方法可以大致分为度量学习和元学习。度量学习需要一个足够灵活的特征提取器,它负责将基类数据集中学到的特征表达最大限度地转移到下游任务中。度量学习常用的方法包括余弦相似度、欧氏距离、关系网络和图神经网络。元学习提出情景再现的训练方式。训练阶段和测试阶段使用相同的任务组织形式帮助模型提取到有用的规则。元学习可以很好的处理不同任务之间的差异,从而忽略具体任务的特征。国内外关于小样本学习在度量学习和元学习方面有很多代表性的工作。其中原型网络因其思想简单、结果有效而受到广泛欢迎。它通过计算支持集中特征向量的平均值作为原型,使用欧氏距离比较查询样本和原型之间距离的远近。随后,Baseline++进一步将特征向量归一化,计算输入特征与权重矩阵在标准球面的余弦相似度。关系网络更进一步抽象了相似度计算,设计一个专门的模型用来直接输出结果。类似的工作还包括图神经网络。在图神经网络中,先验知识通过消息传播机制,将支持集的特征和标签传播到查询集。元学习方面,MAML提出了任务场景下情景再现的训练机制。此后的一些工作在MAML的基础上进行了一系列的改进,提高了任务空间参数初始化的能力;MTFL运用元学习策略,将DNN微调所需的任务数下降到只有8000个;MetaOptNet更进一步使用线性支持向量机作为内循环改善特征向量的表示。最近,Meta Baseline首次全面探索了元学习和全分类学习器之间的优势,提出了小样本学习领域全新的基线模型。
现有的技术用来解决小样本学习存在不同程度的问题,首先是特征提取器的问题,传统的监督学习学习到的知识均和标签相关,对于没有见过的类别缺乏比较泛化的底层的特征表达,这就极大的制约了小样本学习的发展。大多数工作都是基于迁移学习展开,尝试对特征提取器固定一部分参数,从而微调另一部分的参数。实验结果表明这种方法对于领域相差不大的数据集来说,效果较为明显。还有一种方案是基于多任务的学习来弥补小样本学习自身数据不足的窘境。通过收集不同类型的相关任务,损失函数迫使不同任务的优化均同步进行,这在一定程度上不会使得参数刻意的逼近任何一个任务,结果就会使得额外的小样本任务也表现较好。简单的多任务学习对小样本学习提升效果有限,结合元学习能更加释放这种优势。但是元学习的存在会出现参数优化不稳定的问题,测试集结果受采样的任务波动较大的情况。
发明内容
本发明的目的是针对现有技术的不足而提供的一种基于监督对比学习和多任务设置的小样本图像分类优化方法,采用标签掩码矩阵和背景混淆方法,基于小样本元基线,以监督学习和多任务为切入点,通过多任务学习,引导特征提取器学习更多语义信息,增强元知识的一般性,提高了小样本学习在一般的图像分类任务的准确度,方法简便,使用效果好,有较高的实用价值与良好的发展前景。
实现本发明目的的具体技术方案是:一种基于监督对比学习和多任务设置的小样本图像分类优化方法,其特点是该方法主要包括:
1)采用标签掩码矩阵和背景混淆方法,分别应用于实例对比学习及其代理任务,监督对比学习区别于实例对比学习最大的特点就是正样本不仅包含目标样本经过背景混淆的样本还包括同类别的其他样本,标签掩码矩阵和背景混淆方法在mini-imgeNet数据集实现了基线方法中最好的结果。
2)采用在小样本分类器基础上增加多任务设置,具体包括:预测图像旋转角度、预测图像补丁的相对位置和图像嵌入空间聚类任务,辅助任务在一定程度上可以帮助特征提取器学习更加底层的语义信息。在小样本分类阶段,相关度量方法可以更好的计算嵌入空间中查询样本和支持样本之间的距离。大量的实验证明,图像嵌入聚类的辅助任务在各项基准数据集中均表现最好。
本发明具体包括下述三个阶段:
S1、特征提取器训练阶段
借助标签掩码矩阵M记录每个批次内样本标签的异同,并将标签掩码矩阵M作为类别监督信号引入实例对比学习参与图像表征训练。
所述步骤S1的特征提取器训练具体包括下述步骤:
S1-1:记录每个批次内样本标签的异同
标签掩码矩阵M记录每个批次内样本标签的异同,标签掩码矩阵M的大小为[bsz,bsz],其中bsz为批次内数据点的长度,所述标签掩码矩阵M里的元素均由1和0构成,如果样本i和j具有相同的标签,标签掩码矩阵对应的值为1,反之则为0,即mask{i,j}=1,mask{i,j}=0。
S1-2:监督对比学习
将正样本对的范围扩充到相同类别的其他样本中,使用监督对比损,得到特征提取器F,在损失函数的计算过程中,只需要找出基准样本所在行或者列元素值为1的编号即可,所述监督对比损失Lsup由下述(a)式计算:
式中,P(i)为与基准样本标签一致的其他样本索引集合;Fi是样本i对应的特征向量,同理,Fp和Fa分别对应样本p和a对应的特征向量;|P(i)|为与基准样本标签一致的其他样本索引集合的基数;A(i)为除相同类别外的其他负样本索引;γ为标量-温度参数;
通过监督对比损失,得到一个足够灵活的特征提取器F,在小样本分类阶段,支持集Csupport会被送入特征提取器F获得相应的类别原型。如果5-way 1-shot任务,支持集Csupport的特征向量直接作为类别原型,如果任务是5-way 5-shot任务,按下述(c)式计算计算5个特征向量的平均值作为原型:
式中,x为支持集Csupport中的一个样本点;|Csupport|是支持集的基数;θ是特征提取器F的参数;Wn为每个任务中的原型向量。
特征提取器F和余弦相似度的分类器一起在基类数据集上进行训练。Wb=[W1,W2,……,Wn]表示支持集中每个类别的d维原型向量。需要注意的是,余弦相似度的p(y=c|x)的计算需要对特征向量进行归一化处理由下述(d)式表示为:
式中,<>表示两个向量在归一化之后之间的余弦相似度运算;Fθ(x)为样本x对应的特征向量,同理,fθ(x)也是其对应的特征向量;wn原型矩阵中除了目标原型wn'的元素。
在小样本分类任务中,查询集样本和原型的相似度得分均在-1和1之间,数值越接近1,两个向量在嵌入空间中就越相似。反之,两个向量的差异越大。在预测阶段,查询集中的图像会被归入与之最接近的原型中。原型的标签即为查询集中图像的标签。
S2:代理任务阶段
评估传统的代理任务例如数据增强,视图转化对实例对比学习的贡献,设计替换为更加适应小样本学习的背景混淆方法,背景裁剪方法主要由显著性检测和背景裁剪两部分组成。背景混淆方法它不同于传统的数据增强和视图转化,背景混淆可以增强图像的前景表达。假设整个基类数据集为Cbase。整个训练过程中,数据集被组织成N-way K-shot的形式,其中支持集Csupport有N类别,每个类别有K个样本。查询集Cquery表示的是需要分类的样本。Csupport中每一个张图片都可以将前景转移到其他类别的图片上形成增强样本。具体的,背景混淆方法包括两个主要部分:
1)显著性映射模块:显著性映射模块冻结了预训练模型除输入层以外的预训练权重,计算标签得分相对于输入图像的梯度。在获得输入图像中每个像素的显著性得分后,搜索所有像素中得分最高的一个点作为原点,计算人工设定区域内显著性得分的平均值,如果大于前景阈值这个区域就被认为找到了。反之,如果得分均值小于阈值,需要继续进行搜索。
2)背景混淆模块:对于基类数据集参与训练的每个批次数据,采样输入数据点和其他类别中随机抽取另一个数据点,利用显著性映射模块提取前景区域。借助矩阵裁剪运算,将这两个数据点的的前景进行交换,得到新的背景混淆的增强样本。
S3、小样本分类阶段
在余弦分类器的基础上加入了多任务设置,缓解了元学习训练带来的参数优化不稳定的弊端。小样本分类阶段使用多任务设置,辅助小样本分类任务探索更加丰富的语义信息。
在多任务设置下的小样本分类共涉及两条优化路径:一条路径基于余弦相似度,参数表示为θ。损失函数定义为Lfew。另一条路径为辅助任务,参数表示为额外的损失为Laux。特别指出,在预测图像旋转角度和图像相对位置的任务中,小样本分类任务和辅助任务的数据集数量保持不变。在嵌入空间聚类的辅助任务中,支持集数据集还会经过一个前景增强模块,样本数量变为之前的两倍。整个多任务设置的损失函数Ltotal由下述(b)式表示为:
式中,α为Lfew和Laux两个损失函数的正则化超参数加权系数。Wb代表原型特征向量的集合。θ和分别是特征提取器和辅助任务的参数。
多任务学习的每一次参数更新均涉及共享的特征提取器F的参数和各自分类头的参数。下面具体介绍多任务的各项工作。
1)预测图像旋转角度任务:在辅助任务中我们对支持集中的图像进行了相应的预处理。具体做法是将图像进行有规则的旋转并记录相应的标签。为了方便训练,分别对每个图像进行了0度、90度、180度和270度的旋转。由于不需要额外的成本进行标注,预处理产生的计算量可以忽略不计。经过旋转增强后的样本需要经过特征提取器变成特征向量,再进入专门设计的角度预测网络R输出对应的标签。考虑到图像旋转角度预测任务的简单性,这里使用一个多层全连接神经网络进行实现,超参数α设置为1。
2)预测图像补丁的相对位置:常规操作就是将每支持集图像等分为9份,每一个补丁都基于中心位置进行了位置编码,分别为左上,中上,右上,左,右,左下,中下和右下。在实验中,通过对除中心位置以外的8个可能的位置都可以进行相对位置预测,可能的结果包含8种,但是正确的结果只有一种。考虑到预测图像块相对位置的任务所含的可能性更多,采用性能更高的余弦分类器进行预测。图像补丁被分的越细,每一块所携带有效的语义信息就越少。辅助任务的难度也就越大,本发明中设置了3x3的切分,超参数α设置为1。
3)图像嵌入聚类:由于以上两类辅助任务在特征提取器阶段并没有参与预训练,加入新的任务需要从头开始调整参数。受阶段一的预训练启发,我们采用监督对比提出图像聚类辅助任务。在嵌入空间中,图像聚类任务可以帮助小样本学习很好的预测新的类别。从损失函数的角度来看,它与小样本本身的分类任务优化方向一致。监督对比损失的分子是所有正样本对在嵌入空间之间的距离,分母为所有负样本对在嵌入空间之间的距离。只有当相同类别的标签不断靠近,不同类别的标签不断原理,图像聚类任务的损失值才会越来越小,多任务设置下总的损失值才会越来越小,超参数α设置为2。
本发明遵循Meta Baseline的学习框架,主要解决了小样本学习在预训练阶段和分类阶段中存在的两个显著问题。在预训练阶段,实例对比学习仅关注自身的属性,导致特征提取器缺乏全局的类别边界信息。通过加入标签掩码矩阵,可以为实例监督对比学习提供监督信号。其次在小样本分类器阶段,元学习框架下的余弦相似度存在参数优化不稳定的问题,测试集结果受采样的任务波动较大的情况,为了缓解这个问题,又提出了多任务的实验设置。综上所述,本发明主要研究了面向监督对比学习的编码器在小样本方向的特征表达,以背景混淆、多任务为切入点开展研究。
本发明与现有技术相比具有基于小样本元基线方法,以监督学习和多任务为切入点,设计了标签掩码矩阵、背景混淆方法,多任务学习,引导特征提取器学习更多语义信息,增强元知识的一般性,提高了小样本学习在一般的图像分类任务的准确度,方法简便,使用效果好,有较高的实用价值与良好的发展前景。
附图说明
图1为本发明架构的示意图;
图2为预测图像旋转角度任务示意图;
图3为预测图像补丁相对位置任务示意图;
图4为图像嵌入聚类任务示意图。
具体实施方式
以下结合附图及实施例对本发明进行详细描述。显然,所列举的实例只用于解释本发明,并非用于限定本发明的范围。
实施例1
参阅图1,本发明具体包括以下步骤:
S1、特征提取器训练阶段,借助标签掩码矩阵M记录每个批次内样本标签的异同,标签掩码矩阵M里面的元素均由1和0构成。如果样本i和j具有相同的标签,标签掩码矩阵M对应的值为1,反之则为0,将标签掩码矩阵M作为类别监督信号引入实例对比学习参与图像表征训练。
所述标签掩码矩阵M将监督对比学习将正样本对的范围扩充到相同类别的其他样本中,在损失函数的计算过程中,只需要找出基准样本所在行或者列元素值为1的编号即可。具体的,监督对比损失由下述(a)式计算:
式中,P(i)为与基准样本标签一致的其他样本索引集合;|P(i)|为与基准样本标签一致的其他样本索引集合的基数;A(i)为除相同类别外的其他负样本索引;Fi是样本i对应的特征向量,同理,Fp和Fa分别对应样本p和a对应的特征向量;γ为标量-温度参数。
通过监督对比损失,得到一个足够灵活的特征提取器F,在小样本分类阶段,支持集Csupport会被送入特征提取器F获得相应的类别原型。如果5-way1-shot任务,支持集Csupport的特征向量直接作为类别原型,如果任务是5-way5-shot任务,按下述(c)式计算计算5个特征向量的平均值作为原型:
式中,x为支持集Csupport中的一个样本点;|Csupport|是支持集的基数;θ是特征提取器F的参数;Wn为代表任务的原型特征向量。
特征提取器F和余弦相似度的分类器一起在基类数据集Cbase上进行训练。Wb=[W1,W2,……,Wn]表示支持集Csupport中每个类别的d维原型向量。需要注意的是,余弦相似度的计算需要对特征向量进行归一化处理,其公式由下述(d)式具体表示为:
式中,<>表示两个向量在归一化之后之间的余弦相似度运算。Fθ(x)和fθ(x)分别代表x对应的特征向量;wn是原型矩阵中除了目标原型wn'的元素。
在小样本分类任务中,查询集Cquery样本和原型的相似度得分均在-1和1之间。数值越接近1,两个向量在嵌入空间中就越相似。反之,两个向量的差异越大。在预测阶段,查询集Cquery中的图像会被归入与之最接近的原型中。原型的标签即为查询集Cquery中图像的标签。
S2:评估传统的代理任务例如数据增强,视图转化对实例对比学习的贡献,设计替换为更加适应小样本学习的背景混淆方法。它不同于传统的数据增强和视图转化。背景混淆可以增强图像的前景表达。
假设整个基类数据集为Cbase,整个训练过程中,数据集被组织成N-way K-shot的形式,其中支持集Csupport有N类别,每个类别有K个样本。查询集Cquery表示的是需要分类的样本。Csupport中每一个张图片都可以将前景转移到其他类别的图片上形成增强样本。具体的,背景混淆方法包括两个主要部分:
1)显著性映射模块:显著性映射模块冻结了预训练模型除输入层以外的预训练权重,计算标签得分相对于输入图像的梯度,在获得输入图像中每个像素的显著性得分后,搜索所有像素中得分最高的一个点作为原点,计算人工设定区域内显著性得分的平均值,如果大于前景阈值这个区域就被认为找到了;反之,如果得分均值小于阈值,需要继续进行搜索。
2)背景混淆模块:对于基类数据集Cbase参与训练的每个批次数据,采样输入数据点和其他类别中随机抽取另一个数据点,利用显著性映射模块提取前景区域。借助矩阵裁剪运算,将这两个数据点的的前景进行交换,得到新的背景混淆的增强样本。
所述显著性映射实现背景混淆算法的伪代码如下述表1所示:
表1伪代码描述如下:
S3、小样本分类阶段,在余弦分类器的基础上加入了多任务设置,缓解了元学习训练带来的参数优化不稳定的弊端,同时辅助小样本分类任务探索更加丰富的语义信息。
参阅图1,多任务设置下的小样本分类共涉及两条优化路径,一路径基于余弦相似度,参数表示为θ,其损失函数定义为Lfew。另一路径为辅助任务,参数表示为额外的损失为Laux。特别指出,在预测图像旋转角度和图像相对位置的任务中,小样本分类任务和辅助任务的数据集数量保持不变。在嵌入空间聚类的辅助任务中,支持集Csupport、基类数据集为Cbase还会经过一个前景增强模块,样本数量变为之前的两倍。整个多任务设置的损失函数由下述(b)式表示为:
式中,α为Lfew和Laux两个损失函数的正则化超参数加权系数;Wb代表原型特征向量的集合;θ和分别是特征提取器和辅助任务的参数。
多任务学习的每一次参数更新均涉及共享的特征提取器F的参数和各自分类头的参数,下面具体介绍多任务的各项工作。
1)预测图像旋转角度任务
参阅如图2,在辅助任务中对支持集Csupport中的图像进行了相应的预处理。具体做法是将图像进行有规则的旋转并记录相应的标签。为了方便训练,分别对每个图像进行了0度、90度、180度和270度的旋转。整个框架如图2所示。由于不需要额外的成本进行标注,预处理产生的计算量可以忽略不计。经过旋转增强后的样本需要经过特征提取器F变成特征向量,再进入专门设计的角度预测网络R输出对应的标签。考虑到图像旋转角度预测任务的简单性,使用一个多层全连接神经网络进行实现,超参数α设置为1。
2)预测图像补丁的相对位置
参阅图3,该图展示了预测图像补丁相对位置的框架,常规操作就是将每支持集图像等分为9份,每一个补丁都基于中心位置进行了位置编码,分别为左上、中上、右上、左、右、左下、中下和右下。在实验中,通过对除中心位置以外的8个可能的位置都可以进行相对位置预测,可能的结果包含8种,但是正确的结果只有一种。考虑到预测图像块相对位置的任务所含的可能性更多,采用性能更高的余弦分类器进行预测。图像补丁被分的越细,每一块所携带有效的语义信息就越少。辅助任务的难度也就越大。本发明中设置了3x3的切分,超参数α设置为1。
3)图像嵌入聚类:由于以上两类辅助任务在特征提取器F阶段并没有参与预训练,加入新的任务需要从头开始调整参数。受阶段一的预训练启发,采用监督对比提出图像聚类辅助任务。
参阅图4,在嵌入空间中,图像聚类任务可以帮助小样本学习很好的预测新的类别。从损失函数的角度来看,它与小样本本身的分类任务优化方向一致。监督对比损失的分子是所有正样本对在嵌入空间之间的距离,分母为所有负样本对在嵌入空间之间的距离。只有当相同类别的标签不断靠近,不同类别的标签不断原理,图像聚类任务的损失值才会越来越小,多任务设置下总的损失值才会越来越小。超参数α设置为2。
以上具体实施只是对本发明做进一步说明,并非用以限制本发明专利,凡为本发明等效实施,均应包含于本发明专利的权利要求范围之内。
Claims (7)
1.一种基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,该方法具体包括下述三个阶段:
S1、特征提取器训练阶段
借助标签掩码矩阵M记录每个批次内样本标签的异同,并将标签掩码矩阵M作为类别监督信号引入实例对比学习参与图像表征训练,所述标签掩码矩阵M里的元素均由1和0构成,如果样本i和j具有相同的标签,标签掩码矩阵对应的值为1,反之则为0;
S2:代理任务阶段
将标签掩码矩阵M和背景混淆方法,分别应用于实例对比学习及其代理任务;
S3、小样本分类阶段
在小样本分类器的基础上增加多任务设置进行小样本分类,所述多任务设置具体包括:预测图像旋转角度、预测图像补丁的相对位置和图像嵌入空间聚类。
2.根据权利要求1所述基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,所述步骤S1的特征提取器训练具体包括下述步骤:
S1-1:记录每个批次内样本标签的异同
标签掩码矩阵M记录每个批次内样本标签的异同,标签掩码矩阵M的大小为[bsz,bsz],其中bsz为批次内数据点的长度,如果样本i和j具有相同的标签,标签掩码矩阵M对应的值为1,反之则为0,即mask{i,j}=1,mask{i,j}=0;
S1-2:监督对比学习
将正样本对的范围扩充到相同类别的其他样本中,使用监督对比损,得到特征提取器F,在损失函数的计算过程中,只需找出基准样本所在行或者列元素值为1的编号即可,所述监督对比损失Lsup由下述(a)式计算:
式中,P(i)为与基准样本标签一致的其他样本索引集合;|P(i)|为与基准样本标签一致的其他样本索引集合的基数;A(i)为除相同类别外的其他负样本索引;γ为标量-温度参数。
3.根据权利要求1所述基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,所述步骤S2中的背景混淆方法是增强图像的前景表达,假设整个基类数据集为Cbase,查询集为Cquery,整个训练过程中,Cbase被组织成N-way K-shot的形式,其中支持集Csupport有N类别,每个类别有K个样本,查询集Cquery表示的是需要分类的样本,支持集Csupport中每一个张图片都将前景转移到其他类别的图片上形成增强样本,所述背景混淆方法包括:显著性检测和背景裁剪两部分;所述显著性检测使用显著性映射模块冻结预训练模型除输入层以外的预训练权重,计算标签得分相对于输入图像的梯度,在获得输入图像中每个像素的显著性得分后,搜索所有像素中得分最高的一个点作为原点,计算人工设定区域内显著性得分的平均值,如果大于前景阈值这个区域就被认为找到了;反之,如果得分均值小于阈值,则需继续进行搜索;所述背景裁剪使用背景混淆模块,对于基类数据集Cbase参与训练的每个批次数据,采样输入数据点和其他类别中随机抽取另一个数据点,利用显著性映射模块提取前景区域,借助矩阵裁剪运算,将这两个数据点的的前景进行交换,得到新的背景混淆的增强样本。
4.根据权利要求1所述基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,所述步骤S3中多任务设置下的小样本分类使用基于余弦相似度和辅助任务,所述基于余弦相似度的参数表示为θ,损失函数定义为Lfew;所述辅助任务的参数表示为额外的损失为Laux;所述多任务设置在预测图像旋转角度和图像相对位置的任务中,小样本分类任务和辅助任务的数据集数量保持不变,在嵌入空间聚类的辅助任务中,支持集Dsupport和基类数据集Cbase经前景增强模块使样本数量增大两倍,整个多任务设置的损失函数Ltotal由下述(b)式表示为:
式中,α为Lfew和Laux两个损失函数的正则化超参数加权系数;Wb为每个原型代表的特征向量;
所述多任务学习的每一次参数更新均涉及共享的特征提取器F的参数和各自分类头的参数。
5.根据权利要求1所述基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,所述步骤S3在小样本分类阶段支持集Csupport被送入特征提取器F获得相应的类别原型;如果为5-way 1-shot任务,支持集Csupport的特征向量直接作为类别原型;如果为5-way 5-shot任务,按下述(c)式计算计算5个特征向量的平均值Wn作为原型:
式中,x为支持集Csupport中的一个样本点;|Csupport|为Csupport的基数;θ是特征提取器F的参数;
所述特征提取器F和余弦相似度的分类器一起在基类数据集Cbase上进行训练,Wb=[W1,W2,……,Wn]表示支持集Csupport中每个类别的d维原型向量;
所述余弦相似度的计算需对特征向量进行归一化处理由下述(d)式表示为:
式中,<>表示两个向量在归一化之后之间的余弦相似度运算;wn为原型矩阵中除了目标原型wn'的元素;Fθ(x)和fθ(x)分别为样本x对应的特征向量;查询集Cquery样本和原型的相似度得分均在-1和1之间,数值越接近1,两个向量在嵌入空间中就越相似;反之,两个向量的差异越大;在预测阶段,查询集Cquery中的图像被归入与之最接近的原型中,原型的标签即为查询集Cquery中图像的标签。
6.根据权利要求1或权利要求4所述基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,所述预测图像旋转角度任务在辅助任务中对支持集Csupport中的图像进行预处理,所述预处理是将图像进行0度、90度、180度和270度的旋转,并记录相应的标签,旋转增强后的样本经特征提取器F变为特征向量,并以角度预测网络R输出对应的标签;所述角度预测网络R使用一个多层全连接神经网络,其超参数α设置为1,预测图像补丁的相对位置是将每一支持集Csupport图像等分为9份,每一个补丁都基于中心位置进行了位置编码,分别为左上、中上、右上、左、右、左下、中下和右下,通过对除中心位置以外的8个的位置都进行相对位置预测。
7.根据权利要求1或权利要求4所述基于监督对比学习和多任务设置的小样本图像分类优化方法,其特征在于,所述图像嵌入聚类采用监督对比提出图像聚类辅助任务,监督对比损失的分子是所有正样本对在嵌入空间之间的距离,分母为所有负样本对在嵌入空间之间的距离,只有当相同类别的标签不断靠近,不同类别的标签不断原理,图像聚类任务的损失值才会越来越小,多任务设置下总的损失值才会越来越小,其超参数α设置为2。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310366976.XA CN116524242A (zh) | 2023-04-07 | 2023-04-07 | 基于监督对比学习和多任务设置的小样本图像分类优化方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310366976.XA CN116524242A (zh) | 2023-04-07 | 2023-04-07 | 基于监督对比学习和多任务设置的小样本图像分类优化方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116524242A true CN116524242A (zh) | 2023-08-01 |
Family
ID=87391247
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310366976.XA Pending CN116524242A (zh) | 2023-04-07 | 2023-04-07 | 基于监督对比学习和多任务设置的小样本图像分类优化方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116524242A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117520551A (zh) * | 2024-01-08 | 2024-02-06 | 南京邮电大学 | 一种小样本文本自动分类方法及系统 |
CN118506290A (zh) * | 2024-07-19 | 2024-08-16 | 贵州交建信息科技有限公司 | 基于ai识别的梁场施工安全质量监测方法及系统 |
-
2023
- 2023-04-07 CN CN202310366976.XA patent/CN116524242A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117520551A (zh) * | 2024-01-08 | 2024-02-06 | 南京邮电大学 | 一种小样本文本自动分类方法及系统 |
CN117520551B (zh) * | 2024-01-08 | 2024-05-10 | 南京邮电大学 | 一种小样本文本自动分类方法及系统 |
CN118506290A (zh) * | 2024-07-19 | 2024-08-16 | 贵州交建信息科技有限公司 | 基于ai识别的梁场施工安全质量监测方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Abdar et al. | A review of uncertainty quantification in deep learning: Techniques, applications and challenges | |
Jain et al. | Synthetic data augmentation for surface defect detection and classification using deep learning | |
CN116524242A (zh) | 基于监督对比学习和多任务设置的小样本图像分类优化方法 | |
Du et al. | Semi-siamese training for shallow face learning | |
Jayaraman et al. | Learning image representations tied to egomotion from unlabeled video | |
Zhao et al. | FaNet: Feature-aware network for few shot classification of strip steel surface defects | |
CN113535947B (zh) | 一种带有缺失标记的不完备数据的多标记分类方法及装置 | |
Liu et al. | Fabric defect detection based on lightweight neural network | |
CN117611576A (zh) | 一种基于图文融合对比学习预测方法 | |
Tang et al. | A Siamese network-based tracking framework for hyperspectral video | |
Baharani et al. | Real-time person re-identification at the edge: A mixed precision approach | |
Luo et al. | RBD-Net: robust breakage detection algorithm for industrial leather | |
Deng et al. | Emotion class-wise aware loss for image emotion classification | |
Ge et al. | Deep spatial attention hashing network for image retrieval | |
CN117611901A (zh) | 一种基于全局和局部对比学习的小样本图像分类方法 | |
Wan et al. | Deep feature representation and ball-tree for face sketch recognition | |
Song et al. | Deep discrete hashing with self-supervised pairwise labels | |
Ma | Fixed-point tracking of English reading text based on mean shift and multi-feature fusion | |
Zhang et al. | Pairwise teacher-student network for semi-supervised hashing | |
Huang et al. | Industrial few-shot fractal object detection | |
CN115272688A (zh) | 一种基于元特征的小样本学习图像分类方法 | |
Wang et al. | Multi-scale pyramidal hash learning for traditional building facade image retrieval | |
Yu et al. | An efficient prototype-based model for handwritten text recognition with multi-loss fusion | |
Ren et al. | Video-based emotion recognition using multi-dichotomy RNN-DNN | |
Liu et al. | SCA-YOLOv4: you only look once with squeeze-and-excitation, coordinate attention and adaptively spatial feature fusion |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |