CN117496243A - 基于对比学习的小样本分类方法及系统 - Google Patents
基于对比学习的小样本分类方法及系统 Download PDFInfo
- Publication number
- CN117496243A CN117496243A CN202311462624.0A CN202311462624A CN117496243A CN 117496243 A CN117496243 A CN 117496243A CN 202311462624 A CN202311462624 A CN 202311462624A CN 117496243 A CN117496243 A CN 117496243A
- Authority
- CN
- China
- Prior art keywords
- layer
- feature map
- sample
- module
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 21
- 239000013598 vector Substances 0.000 claims abstract description 106
- 238000000605 extraction Methods 0.000 claims abstract description 46
- 238000012549 training Methods 0.000 claims abstract description 13
- 230000004913 activation Effects 0.000 claims description 196
- 238000011176 pooling Methods 0.000 claims description 73
- 238000000926 separation method Methods 0.000 claims description 61
- 230000006870 function Effects 0.000 claims description 51
- 238000010606 normalization Methods 0.000 claims description 49
- 238000010586 diagram Methods 0.000 claims description 31
- 238000012935 Averaging Methods 0.000 claims description 30
- 238000004364 calculation method Methods 0.000 claims description 11
- 238000010276 construction Methods 0.000 claims description 3
- 238000005096 rolling process Methods 0.000 claims description 2
- 210000002569 neuron Anatomy 0.000 description 4
- 238000005070 sampling Methods 0.000 description 4
- 230000006978 adaptation Effects 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 230000001965 increasing effect Effects 0.000 description 2
- 238000013507 mapping Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 201000007131 Placental site trophoblastic tumor Diseases 0.000 description 1
- 230000003213 activating effect Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000008033 biological extinction Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000009191 jumping Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0499—Feedforward networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明提供了基于对比学习的小样本分类方法及系统,属于图像分类技术领域。首先构建特征提取网络模型;其次将基础数据集划分为三元样本对,将三元样本对输入到特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;然后计算正锚相似度和负锚相似度;再将正锚相似度和负锚相似度进行拼接,得到整体相似度,计算整体相似度与标签的损失,判断损失是否满足条件,若满足条件,则计算支持集损失,并更新线性分类器参数。本发明在预训练阶段使用对比学习,使模型有效的学习到样本的空间信息,极大的提高模型在新类任务中的泛化能力。同时使用对比学习进一步训练并加以微调,使模型快速收敛,节约大量的时间与资源。
Description
技术领域
本发明属于图像分类技术领域,具体涉及基于对比学习的小样本分类方法及系统。
背景技术
小样本学习(Few-ShotLearning,FSL)是指在只有少量标注数据的情况下,训练一个模型能够识别新的类别的任务。这是一种具有挑战性的机器学习问题,因为在数据稀缺的情况下,模型很容易过拟合或欠拟合。
传统的深度学习方法面临着重大挑战,因为这些方法通常需要大量的标注数据才能实现良好的泛化性能。然而,在现实世界的应用中,收集和标注大量数据是昂贵和不切实际的。因此,开发能够在少量样本情况下仍然保持高性能的算法变得至关重要。小样本学习任务通常可以分为两个阶段:预训练和微调。预训练阶段,模型在一个较大的基础类数据集上学习,以获得泛化的特征提取器。微调阶段,特征提取器会被进一步调整以适应新的、少量样本的任务。
在这个背景下,对比学习作为一种强有力的无监督学习策略,被广泛应用于预训练阶段。对比学习旨在学习区分不同类别样本的能力,而不是在大量标注数据上进行训练。通过这种方式,模型可以学习到更加鲁棒和泛化的特征表示。然而,如何将这些特征有效地转移到新的、未见过的类别上,仍然是一个值得思考的研究问题。
发明内容
基于上述技术问题,本发明提供基于对比学习的小样本分类方法及系统,在预训练阶段使用对比学习,使得模型获得了特征提取器来适应下游任务。在适应阶段,利用正则化手段,进一步提升模型泛化能力与鲁棒性。
本发明提供了基于对比学习的小样本分类方法,所述方法包括:
步骤S1:构建特征提取网络模型;所述特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;所述两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,所述两个残差模块分别为第一残差模块和第二残差模块,所述两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块;
步骤S2:将基础数据集划分为三元样本对,将所述三元样本对输入到所述特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;所述三元样本对包括正样本、锚点样本和负样本;
步骤S3:计算所述锚点样本特征向量与所述正样本特征向量的余弦相似度,得到正锚相似度;计算所述锚点样本特征向量与所述负样本特征向量的余弦相似度,得到负锚相似度;
步骤S4:将所述正锚相似度和所述负锚相似度进行拼接,得到整体相似度,计算所述整体相似度与标签的损失,判断所述损失是否小于或等于第一阈值;如果所述损失大于所述第一阈值,则返回“步骤S2”;如果所述损失小于或等于所述第一阈值,则执行“步骤S5”;
步骤S5:计算支持集损失,并更新线性分类器参数。
可选地,所述第一残差模块,具体包括:
所述第一残差模块包括第一残差输入层、第二标准卷积层、第二规范化激活层、第三标准卷积层、第三规范化激活层、第一元素相加层、第四标准卷积层、第四规范化激活层、第二元素相加层、第五标准卷积层、第五规范化激活层、第一张量拼接层、第六标准卷积层、第一批归一化层、第七标准卷积层、第二批归一化层、第三元素相加层和第一激活函数层;
将所述第一残差输入层的特征图F2依次输入到所述第二标准卷积层和所述第二规范化激活层进行卷积和激活操作,得到特征图F4;
将所述特征图F4依次输入到所述第三标准卷积层和所述第三规范化激活层进行卷积和激活操作,得到特征图F6;
将所述特征图F4和所述特征图F6输入到所述第一元素相加层进行元素相加操作,得到特征图F7;
将所述特征图F7依次输入到所述第四标准卷积层和所述第四规范化激活层进行卷积和激活操作,得到特征图F9;
将所述特征图F7和所述特征图F9输入到所述第二元素相加层进行元素相加操作,得到特征图F10;
将所述特征图F10依次输入到所述第五标准卷积层和所述第五规范化激活层进行卷积和激活操作,得到特征图F12;
将所述特征图F4、所述特征图F6、所述特征图F9和所述特征图F12输入到所述第一张量拼接层进行张量拼接操作,得到特征图F13;
将所述特征图F13依次输入到所述第六标准卷积层和所述第一批归一化层进行卷积和归一化操作,得到特征图F15;
将所述第一残差输入层的所述特征图F2依次输入到所述第七标准卷积层和所述第二批归一化层进行卷积和归一化操作,得到特征图F17;
将所述特征图F15和所述特征图F17输入到所述第三元素相加层进行元素相加操作,得到特征图F18;
将所述特征图F18输入到所述第一激活函数层进行激活操作,得到特征图F19。
可选地,所述第一分离注意力模块,具体包括:
所述第一分离注意力模块包括第一分离注意力输入层、第一全局平均池化层、第一全连接激活层、第二全连接激活层、第一维度扩展层、第二维度扩展层、第一元素相乘层、第一最大池化层、第一平均池化层、第二张量拼接层、第八标准卷积层、第二激活函数层、第二元素相乘层、第一深度可分离卷积层和第六规范化激活层;
将所述第一分离注意力输入层的特征图F19输入到所述第一全局平均池化层进行全局平均池化操作,得到特征图F20;
将所述特征图F20依次输入到所述第一全连接激活层和所述第二全连接激活层进行全连接激活操作,得到特征图F24;
将所述特征图F24依次输入到所述第一维度扩展层和所述第二维度扩展层进行维度扩展操作,得到特征图F26;
将所述第一分离注意力输入层的所述特征图F19和所述特征图F26输入到所述第一元素相乘层进行元素相乘操作,得到特征图F27;
将所述第一分离注意力输入层的所述特征图F19输入到所述第一最大池化层进行最大池化操作,得到特征图F28;
将所述第一分离注意力输入层的所述特征图F19输入到所述第一平均池化层进行平均池化操作,得到特征图F29;
将所述特征图F28和所述特征图F29输入到所述第二张量拼接层进行张量拼接操作,得到特征图F30;
将所述特征图F30依次输入到所述第八标准卷积层和所述第二激活函数层进行卷积和激活操作,得到特征图F32;
将所述特征图F27和所述特征图F32输入到所述第二元素相乘层进行元素相乘操作,得到特征图F33;
将所述特征图F33依次输入到所述第一深度可分离卷积层和所述第六规范化激活层进行卷积核激活操作,得到特征图F35。
可选地,所述计算所述锚点样本特征向量与所述正样本特征向量的余弦相似度,得到正锚相似度,计算所述锚点样本特征向量与所述负样本特征向量的余弦相似度,得到负锚相似度,具体公式为:
式中,cos(ai,pi)为正锚相似度,cos(ai,ni)为负锚相似度,ai为锚点样本特征向量,pi为正样本特征向量,ni为负样本特征向量,||||为特征向量范数,i为第i个样本对,i∈[1,M],M为三元样本对总数。
可选地,所述计算支持集损失,并更新线性分类器参数,具体公式为:
Pj=softmax(W·f(Xj)+B)
式中,Xj为支持集样本图片,Yj为支持集样本图片对应标签,W为分类器权重,f(Xj)为支持集样本图片经过所述特征提取网络模型得到的特征向量,B为偏置量,Pj为预测标签值,Regularization为熵正则化项,crossEntropy为交叉熵损失函数。
本发明还提供基于对比学习的小样本分类系统,所述系统包括:
网络模型构建模块,用于构建特征提取网络模型;所述特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;所述两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,所述两个残差模块分别为第一残差模块和第二残差模块,所述两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块;
特征向量提取模块,用于将基础数据集划分为三元样本对,将所述三元样本对输入到所述特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;所述三元样本对包括正样本、锚点样本和负样本;
余弦相似度计算模块,用于计算所述锚点样本特征向量与所述正样本特征向量的余弦相似度,得到正锚相似度;计算所述锚点样本特征向量与所述负样本特征向量的余弦相似度,得到负锚相似度;
训练损失计算模块,用于将所述正锚相似度和所述负锚相似度进行拼接,得到整体相似度,计算所述整体相似度与标签的损失,判断所述损失是否小于或等于第一阈值;如果所述损失大于所述第一阈值,则返回“特征向量提取模块”;如果所述损失小于或等于所述第一阈值,则执行“损失参数更新模块”;
损失参数更新模块,用于计算支持集损失,并更新线性分类器参数。
可选地,所述第一残差模块,具体包括:
第二标准卷积子模块,用于将第一残差输入层的特征图F2依次输入到第二标准卷积层和第二规范化激活层进行卷积和激活操作,得到特征图F4;
第三标准卷积子模块,用于将所述特征图F4依次输入到第三标准卷积层和第三规范化激活层进行卷积和激活操作,得到特征图F6;
第一元素相加子模块,用于将所述特征图F4和所述特征图F6输入到第一元素相加层进行元素相加操作,得到特征图F7;
第四标准卷积子模块,用于将所述特征图F7依次输入到第四标准卷积层和第四规范化激活层进行卷积和激活操作,得到特征图F9;
第二元素相加子模块,用于将所述特征图F7和所述特征图F9输入到第二元素相加层进行元素相加操作,得到特征图F10;
第五标准卷积子模块,用于将所述特征图F10依次输入到第五标准卷积层和第五规范化激活层进行卷积和激活操作,得到特征图F12;
第一张量拼接子模块,用于将所述特征图F4、所述特征图F6、所述特征图F9和所述特征图F12输入到第一张量拼接层进行张量拼接操作,得到特征图F13;
第六标准卷积子模块,用于将所述特征图F13依次输入到第六标准卷积层和第一批归一化层进行卷积和归一化操作,得到特征图F15;
第七标准卷积子模块,用于将所述第一残差输入层的所述特征图F2依次输入到第七标准卷积层和第二批归一化层进行卷积和归一化操作,得到特征图F17;
第三元素相加子模块,用于将所述特征图F15和所述特征图F17输入到第三元素相加层进行元素相加操作,得到特征图F18;
第一激活函数子模块,用于将所述特征图F18输入到第一激活函数层进行激活操作,得到特征图F19。
可选地,所述第一分离注意力模块,具体包括:
将第一分离注意力输入层的特征图F19输入到第一全局平均池化层进行全局平均池化操作,得到特征图F20;
第一二全连接激活子模块,用于将所述特征图F20依次输入到第一全连接激活层和第二全连接激活层进行全连接激活操作,得到特征图F24;
第一二维度扩展子模块,用于将所述特征图F24依次输入到第一维度扩展层和第二维度扩展层进行维度扩展操作,得到特征图F26;
第一元素相乘子模块,用于将所述第一分离注意力输入层的所述特征图F19和所述特征图F26输入到第一元素相乘层进行元素相乘操作,得到特征图F27;
第一最大池化子模块,用于将所述第一分离注意力输入层的所述特征图F19输入到所述第一最大池化层进行最大池化操作,得到特征图F28;
第一平均池化子模块,用于将所述第一分离注意力输入层的所述特征图F19输入到所述第一平均池化层进行平均池化操作,得到特征图F29;
第二张量拼接子模块,用于将所述特征图F28和所述特征图F29输入到第二张量拼接层进行张量拼接操作,得到特征图F30;
第八标准卷积子模块,用于将所述特征图F30依次输入到第八标准卷积层和第二激活函数层进行卷积和激活操作,得到特征图F32;
第二元素相乘子模块,用于将所述特征图F27和所述特征图F32输入到第二元素相乘层进行元素相乘操作,得到特征图F33;
第一深度可分离子模块,用于将所述特征图F33依次输入到第一深度可分离卷积层和第六规范化激活层进行卷积核激活操作,得到特征图F35。
可选地,所述余弦相似度计算模块,具体公式为:
式中,cos(ai,pi)为正锚相似度,cos(ai,ni)为负锚相似度,ai为锚点样本特征向量,pi为正样本特征向量,ni为负样本特征向量,||||为特征向量范数,i为第i个样本对,i∈[1,M],M为三元样本对总数。
可选地,所述损失参数更新模块,具体公式为:
Pj=softmax(W·f(Xj)+B)
式中,Xj为支持集样本图片,Yj为支持集样本图片对应标签,W为分类器权重,f(Xj)为支持集样本图片经过所述特征提取网络模型得到的特征向量,B为偏置量,Pj为预测标签值,Regularization为熵正则化项,crossEntropy为交叉熵损失函数。
本发明与现有技术相比,具有以下有益效果:
本发明通过构建高效的特征提取网络模型,该方法能够更好地理解和区分不同类别的样本,特别是在样本数量有限的情况下。残差模块的使用可以帮助网络学习深层特征而不丢失细节信息,而分离注意力模块则可以增加模型对关键特征的关注,这都有助于提高分类的准确性;使用残差模块和注意力机制可以使网络在增加深度的同时减少训练中的过拟合问题,因为这些模块可以通过跳跃连接和聚焦于重要特征来避免梯度消失和过度依赖少数特征;通过对比学习的方式,模型不仅学习单个样本的特征表示,还学习样本之间的相似性和差异性,这种方式有助于模型在遇到未见过的数据时更好地泛化;在训练时,模型采用小批量的三元样本对,减少了计算资源的需求,这对于资源受限的环境特别有益,也使得这种方法在实际应用中更加灵活和高效;在迁移学习的适应阶段,模型能够快速地适应新任务,这使得模型在实际应用中更加有效和实用;通过在微调阶段引入正则化手段,模型的泛化能力和鲁棒性得到了显著提升。
附图说明
图1为本发明的基于对比学习的小样本分类方法流程图;
图2为本发明的基于对比学习的小样本分类方法中特征提取网络模型图;
图3为本发明的基于对比学习的小样本分类中特征提取网络模型中第一残差模块结构图;
图4为本发明的基于对比学习的小样本分类中特征提取网络模型中第二残差模块结构图;
图5为本发明的基于对比学习的小样本分类中特征提取网络模型中第一分离注意力模块结构图;
图6为本发明的基于对比学习的小样本分类中特征提取网络模型中第二分离注意力模块结构图;
图7为本发明的基于对比学习的小样本分类系统结构图。
具体实施方式
下面结合具体实施案例和附图对本发明作进一步说明,但本发明并不局限于这些实施例。
实施例1
如图1所示,本发明公开基于对比学习的小样本分类方法,方法包括:
步骤S1:构建特征提取网络模型;特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,两个残差模块分别为第一残差模块和第二残差模块,两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块。
步骤S2:将基础数据集划分为三元样本对,将三元样本对输入到特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;三元样本对包括正样本、锚点样本和负样本。
步骤S3:计算锚点样本特征向量与正样本特征向量的余弦相似度,得到正锚相似度;计算锚点样本特征向量与负样本特征向量的余弦相似度,得到负锚相似度。
步骤S4:将正锚相似度和负锚相似度进行拼接,得到整体相似度,计算整体相似度与标签的损失,判断损失是否小于或等于第一阈值;如果损失大于第一阈值,则返回“步骤S2”;如果损失小于或等于第一阈值,则执行“步骤S5”。
步骤S5:计算支持集损失,并更新线性分类器参数。
下面对各个步骤进行详细论述:
步骤S1:构建特征提取网络模型。
图2-图6中,Conv2D表示标准卷积层,卷积核尺寸为7×7,3×3和1×1;Strides表示步长,取值1或2;规范化激活层包含批归一化层(Batch Normalization)和激活函数层(Activation(Relu)),规范化激活层选择Relu激活函数,单独的批归一化层(BatchNormalization),单独的激活函数层(Activation(α)),α取值为Relu和Sigmoid;SepConv2D表示深度可分离卷积层,卷积核尺寸为3×3;ResidualInput表示残差输入层,SepattentionInput表示分离注意力输入层;Dense代表全连接层;Multiply(ε,η)表示ε,η进行逐元素相乘;Add(v,ω)表示v,ω进行逐元素相加;Concat(β,δ)表示β,δ进行张量拼接;Expend_dims(θ)表示对θ进行维度扩展;Maxpooling和Avgpooling分别表示最大池化层和平均池化层;Fσ表示特征提取网络中得到的各特征图,σ取值范围为[1,71],σ为整数。
步骤S1具体包括:
将Mini-ImageNet图像信息(84,84,3)划分成三元样本对输入到第一标准卷积层进行卷积操作,得到特征图F1,第一标准卷积层卷积核数量为32,卷积核尺寸为3×3,步长为2;特征图F1为32通道的42×42;将特征图F1输入到第一规范化激活层进行批归一化和激活操作,得到特征图F2;特征图F2为32通道的42×42。
本实施例中,第一标准卷积模块包括第一标准卷积层和第一规范化激活层。
将第一残差输入层的特征图F2(第一规范化激活层的输出)输入到第二标准卷积层进行卷积操作,得到特征图F3,第二标准卷积层卷积核数量为64,卷积核尺寸为1×1,步长为1;特征图F3为64通道的42×42;将特征图F3输入到第二规范化激活层进行批归一化和激活操作,得到特征图F4;特征图F4为64通道的42×42。
将特征图F4输入到第三标准卷积层进行卷积操作,得到特征图F5,第三标准卷积层卷积核数量为64,卷积核尺寸为3×3,步长为1;特征图F5为64通道的42×42;将特征图F5输入到第三规范化激活层进行批归一化和激活操作,得到特征图F6;特征图F6为64通道的42×42。
将特征图F4和特征图F6输入到第一元素相加层进行元素相加操作,得到特征图F7;特征图F7为64通道的42×42。
将特征图F7输入到第四标准卷积层进行卷积操作,得到特征图F8,第四标准卷积层卷积核数量为64,卷积核尺寸为3×3,步长为1;特征图F8为64通道的42×42;将特征图F8输入到第四规范化激活层进行批归一化和激活操作,得到特征图F9;特征图F9为64通道的42×42。
将特征图F7和特征图F9输入到第二元素相加层进行元素相加操作,得到特征图F10;特征图F10为64通道的42×42。
将特征图F10输入到第五标准卷积层进行卷积操作,得到特征图F11,第五标准卷积层卷积核数量为64,卷积核尺寸为3×3,步长为1;特征图F11为64通道的42×42;将特征图F11输入到第五规范化激活层进行批归一化和激活操作,得到特征图F12;特征图F12为64通道的42×42。
将特征图F4、特征图F6、特征图F9和特征图F12输入第一张量拼接层进行张量拼接操作,得到特征图F13;特征图F13为256通道的42×42。
将特征图F13输入到第六标准卷积层进行卷积操作,得到特征图F14,第六标准卷积层卷积核数量为256,卷积核尺寸为1×1,步长为1;特征图F14为256通道的42×42;将特征图F14输入到第一批归一化层进行批归一化操作,得到特征图F15;特征图F15为256通道的42×42。
将第一残差输入层的特征图F2输入到第七标准卷积层进行卷积操作,得到特征图F16,第七标准卷积层卷积核数量为256,卷积核尺寸为1×1,步长为1;特征图F16为256通道的42×42;将特征图F16输入到第二批归一化层进行批归一化操作,得到特征图F17;特征图F17为256通道的42×42。
将特征图F15和特征图F17输入到第三元素相加层进行元素相加操作,得到特征图F18;特征图F18为256通道的42×42。
将特征图F18输入到第一激活函数层进行激活操作,得到特征图F19;特征图F19为256通道的42×42。
将第一分离注意力输入层的特征图F19(第一激活函数层的输出)输入到第一全局平均池化层进行全局平均池化操作,得到特征图F20;特征图F20为(256,);在使用TensorFlow(尤其是Keras)时,通常None在定义模型的输入形状时包含或省略批量维度。这样做是为了允许模型接受任何大小的批次。只处理操作的输出而不在图层定义中指定形状时,不会显式写出None批量大小;只需使用(256,)形状即可。TensorFlow将其理解为当前张量的形状,并且在运算中使用时,TensorFlow将自动处理批量维度。
将特征图F20依次输入到第一全连接激活层和第二全连接激活层进行全连接激活操作,得到特征图F24,第一全连接激活层包括全连接层和激活函数层,全连接层神经元个数为32,激活函数层为Relu激活函数,特征图F22为(32,);第二全连接激活层包括全连接层和激活函数层,全连接层神经元个数为256,激活函数层为Sigmoid激活函数,特征图F24为(256,)。
将特征图F24依次输入到第一维度扩展层和第二维度扩展层进行维度扩展操作,得到特征图F26;特征图F26为256通道的1×1。
将第一分离注意力输入层的特征图F19和特征图F26输入到第一元素相乘层进行元素相乘操作,得到特征图F27;特征图F26为256通道的42×42。
本实施例中,全局平均池化是指取输入张量中每个通道的空间维度(高度和宽度)的平均值。如果输入张量的形状为(batch_size,height,width,channels),则经过此操作,具有形状(batch_size,channels);应用两个密集层(Dense)来学习通道注意权重。第一个将维度降低到filters=256//8,通过Relu激活引入非线性,第二个通过激活Sigmoid将其恢复到原始数量,这将为每个通道提供介于0和1之间的注意力权重。
将第一分离注意力输入层的特征图F19输入到第一最大池化层进行最大池化操作,得到特征图F28;特征图F28为(42,42,1)。
将第一分离注意力输入层的特征图F19输入到第一平均池化层进行平均池化操作,得到特征图F29;特征图F29为(42,42,1)。
将特征图F28和特征图F29输入到第二张量拼接层进行张量拼接操作,得到特征图F30;特征图F30为(42,42,2)。
将特征图F30输入到第八标准卷积层进行卷积操作,得到特征图F31,第八标准卷积层卷积核数量为1,卷积核尺寸为7×7,步长为1;特征图F31为1通道的42×42;将特征图F31输入到第二激活函数层进行激活操作,得到特征图F32;特征图F32为1通道的42×42。
将特征图F27和特征图F32输入到第二元素相乘层进行元素相乘操作,得到特征图F33;特征图F33为256通道的42×42。
将特征图F33输入到第一深度可分离卷积层进行卷积操作,得到特征图F34,第一深度可分离卷积层卷积核数量为256,卷积核尺寸为3×3,步长为2;特征图F34为256通道的21×21;将特征图F34输入到第六规范化激活层进行批归一化和激活操作,得到特征图F35;特征图F35为256通道的21×21。
本实施例中,第一残差模块包括第一残差输入层、第二标准卷积层、第二规范化激活层、第三标准卷积层、第三规范化激活层、第一元素相加层、第四标准卷积层、第四规范化激活层、第二元素相加层、第五标准卷积层、第五规范化激活层、第一张量拼接层、第六标准卷积层、第一批归一化层、第七标准卷积层、第二批归一化层、第三元素相加层和第一激活函数层。
本实施例中,第一分离注意力模块包括第一分离注意力输入层、第一全局平均池化层、第一全连接激活层、第二全连接激活层、第一维度扩展层、第二维度扩展层、第一元素相乘层、第一最大池化层、第一平均池化层、第二张量拼接层、第八标准卷积层、第二激活函数层、第二元素相乘层、第一深度可分离卷积层和第六规范化激活层。
将第二残差输入层的特征图F35(第六规范化激活层的输出)输入到第九标准卷积层进行卷积操作,得到特征图F36,第九标准卷积层卷积核数量为96,卷积核尺寸为1×1,步长为1;特征图F36为96通道的21×21;将特征图F36输入到第七规范化激活层进行批归一化和激活操作,得到特征图F37;特征图F37为96通道的21×21。
将特征图F37输入到第十标准卷积层进行卷积操作,得到特征图F38,第十标准卷积层卷积核数量为96,卷积核尺寸为3×3,步长为1;特征图F38为96通道的21×21;将特征图F38输入到第八规范化激活层进行批归一化和激活操作,得到特征图F39;特征图F39为96通道的21×21。
将特征图F37和特征图F39输入到第四元素相加层进行元素相加操作,得到特征图F40;特征图F40为96通道的21×21。
将特征图F40输入到第十一标准卷积层进行卷积操作,得到特征图F41,第十一标准卷积层卷积核数量为96,卷积核尺寸为3×3,步长为1;特征图F41为96通道的21×21;将特征图F41输入到第九规范化激活层进行批归一化和激活操作,得到特征图F42;特征图F42为96通道的21×21。
将特征图F40和特征图F42输入到第五元素相加层进行元素相加操作,得到特征图F43;特征图F43为96通道的21×21。
将特征图F43输入到第十二标准卷积层进行卷积操作,得到特征图F44,第十二标准卷积层卷积核数量为96,卷积核尺寸为3×3,步长为1;特征图F44为64通道的42×42;将特征图F44输入到第十规范化激活层进行批归一化和激活操作,得到特征图F45;特征图F45为96通道的21×21。
将特征图F37、特征图F39、特征图F42和特征图F45输入第三张量拼接层进行张量拼接操作,得到特征图F46;特征图F46为384通道的21×21。
将特征图F46输入到第十三标准卷积层进行卷积操作,得到特征图F47,第十三标准卷积层卷积核数量为384,卷积核尺寸为1×1,步长为1;特征图F47为384通道的21×21;将特征图F47输入到第三批归一化层进行批归一化操作,得到特征图F48;特征图F48为384通道的21×21。
将第二残差输入层的特征图F35输入到第十四标准卷积层进行卷积操作,得到特征图F49,第十四标准卷积层卷积核数量为384,卷积核尺寸为1×1,步长为1;特征图F49为384通道的21×21;将特征图F49输入到第四批归一化层进行批归一化操作,得到特征图F50;特征图F50为384通道的21×21。
将特征图F48和特征图F50输入到第六元素相加层进行元素相加操作,得到特征图F51;特征图F51为384通道的21×21。
将特征图F51输入到第三激活函数层进行激活操作,得到特征图F52;特征图F52为384通道的21×21。
将第二分离注意力输入层的特征图F52(第三激活函数层的输出)输入到第二全局平均池化层进行全局平均池化操作,得到特征图F53;特征图F53为(384,)。
将特征图F53依次输入到第三全连接激活层和第四全连接激活层进行全连接激活操作,得到特征图F57,第三全连接激活层包括全连接层和激活函数层,全连接层神经元个数为48,激活函数层为Relu激活函数,特征图F55为(48,);第四全连接激活层包括全连接层和激活函数层,全连接层神经元个数为384,激活函数层为Sigmoid激活函数,特征图F57为(384,)。
将特征图F57依次输入到第三维度扩展层和第四维度扩展层进行维度扩展操作,得到特征图F59;特征图F59为384通道的1×1。
将第二分离注意力输入层的特征图F52和特征图F59输入到第三元素相乘层进行元素相乘操作,得到特征图F60;特征图F60为384通道的21×21。
将第二分离注意力输入层的特征图F52输入到第二最大池化层进行最大池化操作,得到特征图F61;特征图F61为(21,21,1)。
将第二分离注意力输入层的特征图F52输入到第二平均池化层进行平均池化操作,得到特征图F62;特征图F62为(21,21,1)。
将特征图F61和特征图F62输入到第四张量拼接层进行张量拼接操作,得到特征图F63;特征图F63为(21,21,2)。
将特征图F63输入到第十五标准卷积层进行卷积操作,得到特征图F64,第十五标准卷积层卷积核数量为1,卷积核尺寸为7×7,步长为1;特征图F64为1通道的21×21;将特征图F64输入到第四激活函数层进行激活操作,得到特征图F65;特征图F65为1通道的21×21。
将特征图F60和特征图F65输入到第四元素相乘层进行元素相乘操作,得到特征图F66;特征图F66为384通道的21×21。
将特征图F66输入到第二深度可分离卷积层进行卷积操作,得到特征图F67,第二深度可分离卷积层卷积核数量为384,卷积核尺寸为3×3,步长为2;特征图F67为384通道的11×11;将特征图F67输入到第十一规范化激活层进行批归一化和激活操作,得到特征图F68;特征图F68为384通道的11×11。
本实施例中,第二残差模块包括第二残差输入层、第九标准卷积层、第七规范化激活层、第十标准卷积层、第八规范化激活层、第四元素相加层、第十一标准卷积层、第九规范化激活层、第五元素相加层、第十二标准卷积层、第十规范化激活层、第三张量拼接层、第十三标准卷积层、第三批归一化层、第十四标准卷积层、第四批归一化层、第六元素相加层和第三激活函数层。
本实施例中,第二分离注意力模块包括第二分离注意力输入层、第二全局平均池化层、第三全连接激活层、第四全连接激活层、第三维度扩展层、第四维度扩展层、第三元素相乘层、第二最大池化层、第二平均池化层、第四张量拼接层、第十五标准卷积层、第四激活函数层、第四元素相乘层、第二深度可分离卷积层和第十一规范化激活层。
将第十一规范化激活层的输出F68输入到第十六标准卷积层进行卷积操作,得到特征图F69,第十六标准卷积层卷积核数量为512,卷积核尺寸为3×3,步长为2;特征图F69为512通道的6×6;将特征图F69输入到第十二规范化激活层进行批归一化和激活操作,得到特征图F70;特征图F70为512通道的6×6。
本实施例中,第二标准卷积模块包括第十六标准卷积层和第十二规范化激活层。
将第十二规范化激活层的输出F70输入到第三全局平均池化层进行操作,得到特征图F71;特征图F72为(512,)。
本实施例中,全局平均池化层模块包括第三全局平均池化层。
本实施例中,特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,两个残差模块分别为第一残差模块和第二残差模块,两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块。
步骤S2:将基础数据集划分为三元样本对,三元样本对输入到特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量。
步骤S2具体包括:
利用度量学习的数据划分方法对基础数据集(Dbase)进行重新划分,将数据集划分成许多个三元样本对,每个三元样本对中包含三张图片。划分的规则为随机选择一张图片为正样本,并在正样本所在类别选择一张非正样本图片作为锚点样本,再排除该类别在其他类别随机选择一张图片作为副样本。首先,利用特征提取网络模型,得到一个用来提取特征的骨干网络然后,分批次载入之前划分好的三元样本对。最后,利用/>分别提取正样本特征向量positivei,负样本特征向量negativei,锚点样本向量anchori。
本实施例中,三元样本对包括正样本、锚点样本和负样本。
步骤S3:计算锚点样本特征向量与正样本特征向量的余弦相似度,得到正锚相似度;计算锚点样本特征向量与负样本特征向量的余弦相似度,得到负锚相似度。
步骤S3具体包括:
式中,cos(ai,pi)为正锚相似度,cos(ai,ni)为负锚相似度,ai为锚点样本特征向量(anchori),pi为正样本特征向量(positivei),ni为负样本特征向量(negativei),||||为特征向量范数,i为第i个样本对,i∈[1,M],M为三元样本对总数。余弦相似度不仅可以用于评估特征之间的相关性。确定哪些特征在某个任务中具有较高的相关性,从而进行特征的选择和排除。还可用于聚类和分类任务中的样本相似性度量,将相似的样本聚集在一起
步骤S4:将正锚相似度和负锚相似度进行拼接,得到整体相似度,计算整体相似度与标签的损失,判断损失是否小于或等于第一阈值;如果损失大于第一阈值,则返回“步骤S2”;如果损失小于或等于第一阈值,则执行“步骤S5”。
步骤S4具体包括:
式中,S为每个三元样本对进行两次相似度计算,将结果拼接为一个张量,张量长度为2M,M为三元样本对总数,前M个为ai与pi的余弦相似度,后M个为ai与ni的余弦相似度。
设置一个长度2M的一维向量L,L=[1,1,1,1,1…0,0,0,0,0],前M个为正样本对应的标签,标签值为1,后M个为负样本对应的标签,标签值为0,计算整体相似度与标签的损失,损失公式为:
式中,Si为相似度值,Li为标签值;I为张量S和向量L的第I个值。
判断损失是否小于或等于第一阈值;如果损失大于第一阈值,则返回“步骤S2”,重新提取特征向量;如果损失小于或等于第一阈值,则执行“步骤S5”。
步骤S5:计算支持集损失,并更新线性分类器参数。
步骤S5具体包括:
Pj=softmax(W·f(Xj)+B)
式中,Xj为支持集样本图片,Yj为支持集样本图片对应标签,W为分类器权重,f(Xj)为支持集样本图片经过特征提取网络模型得到的特征向量,B为偏置量,Pj为预测标签值,Regularization为熵正则化项,crossEntropy为交叉熵损失函数。
本实施例中,(Xj,Yj)为支持集(support set)中的一个样本,X是图片,Y是标签,利用特征提取网络将Xj映射成一个特征向量,把这个特征向量f(Xj)输入softmax分类器,分类器会输出预测标签值Pj。对于softmax分类器的参数,初始阶段,固定W=M,B=0,首先利用support set中每个类别中所有样本进行特征提取,将其表示为一个代表该类别的特征向量,将这些向量堆叠组成系数矩阵M;利用交叉熵损失进行微调,目标函数为crossEntropy(Yj,Pj),支持集中有几十个标注样本,每个样本都对应一个crossEntropy,将所有的crossEntropy加起来作为损失函数;即用支持集中所有的图片和标签来学习这个分类器。对目标函数为做最小化,最小化公式为:
式中,Pj是向量第j个元素,让预测Pj尽量接近真实标签Yj,最小化公式是对分类器参数W和B求的,再加上一个正则化项防止过拟合,这里的Regularization使用了entropyRegularization;这样可以鼓励模型输出更为均匀的概率分布,增强模型的泛化能力。
实施例2
如图7所示,本发明公开基于对比学习的小样本分类系统,系统包括:
网络模型构建模块10,用于构建特征提取网络模型;特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,两个残差模块分别为第一残差模块和第二残差模块,两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块。
特征向量提取模块20,用于将基础数据集划分为三元样本对,将三元样本对输入到特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;三元样本对包括正样本、锚点样本和负样本。
余弦相似度计算模块30,用于计算锚点样本特征向量与正样本特征向量的余弦相似度,得到正锚相似度;计算锚点样本特征向量与负样本特征向量的余弦相似度,得到负锚相似度。
训练损失计算模块40,用于将正锚相似度和负锚相似度进行拼接,得到整体相似度,计算整体相似度与标签的损失,判断损失是否小于或等于第一阈值;如果损失大于第一阈值,则返回“特征向量提取模块”;如果损失小于或等于第一阈值,则执行“损失参数更新模块”。
损失参数更新模块50,用于计算支持集损失,并更新线性分类器参数。
作为一种可选地实施方式,本发明第一残差模块,具体包括:
第二标准卷积子模块,用于将第一残差输入层的特征图F2依次输入到第二标准卷积层和第二规范化激活层进行卷积和激活操作,得到特征图F4;
第三标准卷积子模块,用于将特征图F4依次输入到第三标准卷积层和第三规范化激活层进行卷积和激活操作,得到特征图F6。
第一元素相加子模块,用于将特征图F4和特征图F6输入到第一元素相加层进行元素相加操作,得到特征图F7。
第四标准卷积子模块,用于将特征图F7依次输入到第四标准卷积层和第四规范化激活层进行卷积和激活操作,得到特征图F9。
第二元素相加子模块,用于将特征图F7和特征图F9输入到第二元素相加层进行元素相加操作,得到特征图F10。
第五标准卷积子模块,用于将特征图F10依次输入到第五标准卷积层和第五规范化激活层进行卷积和激活操作,得到特征图F12。
第一张量拼接子模块,用于将特征图F4、特征图F6、特征图F9和特征图F12输入到第一张量拼接层进行张量拼接操作,得到特征图F13。
第六标准卷积子模块,用于将特征图F13依次输入到第六标准卷积层和第一批归一化层进行卷积和归一化操作,得到特征图F15。
第七标准卷积子模块,用于将第一残差输入层的特征图F2依次输入到第七标准卷积层和第二批归一化层进行卷积和归一化操作,得到特征图F17。
第三元素相加子模块,用于将特征图F15和特征图F17输入到第三元素相加层进行元素相加操作,得到特征图F18。
第一激活函数子模块,用于将特征图F18输入到第一激活函数层进行激活操作,得到特征图F19。
作为一种可选地实施方式,本发明第一分离注意力模块,具体包括:
将第一分离注意力输入层的特征图F19输入到第一全局平均池化层进行全局平均池化操作,得到特征图F20。
第一二全连接激活子模块,用于将特征图F20依次输入到第一全连接激活层和第二全连接激活层进行全连接激活操作,得到特征图F24。
第一二维度扩展子模块,用于将特征图F24依次输入到第一维度扩展层和第二维度扩展层进行维度扩展操作,得到特征图F26。
第一元素相乘子模块,用于将第一分离注意力输入层的特征图F19和特征图F26输入到第一元素相乘层进行元素相乘操作,得到特征图F27。
第一最大池化子模块,用于将第一分离注意力输入层的特征图F19输入到第一最大池化层进行最大池化操作,得到特征图F28。
第一平均池化子模块,用于将第一分离注意力输入层的特征图F19输入到第一平均池化层进行平均池化操作,得到特征图F29。
第二张量拼接子模块,用于将特征图F28和特征图F29输入到第二张量拼接层进行张量拼接操作,得到特征图F30。
第八标准卷积子模块,用于将特征图F30依次输入到第八标准卷积层和第二激活函数层进行卷积和激活操作,得到特征图F32。
第二元素相乘子模块,用于将特征图F27和特征图F32输入到第二元素相乘层进行元素相乘操作,得到特征图F33。
第一深度可分离子模块,用于将特征图F33依次输入到第一深度可分离卷积层和第六规范化激活层进行卷积核激活操作,得到特征图F35。
作为一种可选地实施方式,本发明余弦相似度计算模块30,具体公式:
/>
式中,cos(ai,Pi)为正锚相似度,cos(ai,ni)为负锚相似度,ai为锚点样本特征向量,pi为正样本特征向量,ni为负样本特征向量,||||为特征向量范数,i为第i个样本对,i∈[1,M],M为三元样本对总数。
作为一种可选地实施方式,本发明损失参数更新模块50,具体公式为:
Pj=softmax(W·f(Xj)+B)
式中,Xj为支持集样本图片,Yj为支持集样本图片对应标签,W为分类器权重,f(Xj)为支持集样本图片经过特征提取网络模型得到的特征向量,B为偏置量,Pj为预测标签值,Regularization为熵正则化项,crossEntropy为交叉熵损失函数。
实施例3
本发明在具体数据集进行实验分析,Mini-ImageNet数据集是一个常见的小样本学习的数据集。它包含从ILSVRC-2012中抽取的100个类,每个类包含600张大小为84X84的图像。Flower102数据集是一个细粒度数据集,包括英国常见的102个不同类别的花卉,每个类别包含40到258张图片,共8189张图片。这些图像有很大的尺度,姿势和光线变化。此外,还有一些类别有很大的变化,以及一些非常相似的类别。
基于Keras框架,构建训练好的特征提取模型作为骨干网络,对于预训练阶段,使用SGD优化器,动量为0.9,学习率为0.0005,衰减因子为0.999。在miniImageNet上,在1个GPU上以批次大小为64的方式分别训练了200个epochs,没有使用数据增强;对于适应阶段,使用动量为0.9的SGD优化器学习率为0.005,对训练了500个epochs的结果进行了记录;还应用一致的抽样来评估性能;对于数据集中的新类分割,测试少量任务的抽样遵循一个确定性的顺序;一致性抽样在相同数量的抽样任务下获得更好的模型比较。
表1总结了针对mini-ImageNet的1-shot和5-shot分类结果。尤其是在5-shot分类结果上,取得了81.93+0.27%的准确率。这表明,本发明提出方法,有效的提高了所学嵌入模型的可移植性,通过对比学习,模型对相同类数据更加敏感,当相同类数据增加时,模型更容易将他们归为同一类。
表1 1-shot和5-shot分类结果
Model | 1-shot | 5-shot |
Matching Networks | 43.56±0.84 | 55.31±0.73 |
Prototypical Networks | 48.70±1.84 | 63.11±0.92 |
DN4 | 51.24±0.74 | 71.03±0.64 |
CC+rot | 62.95±0.06 | 79.87±0.33 |
wDAE | 61.07±0.15 | 76.75±0.11 |
PSST | 64.16±0.44 | 80.64±0.32 |
TADAM | 58.50±0.30 | 76.70±0.30 |
ProtoNets+TRAML | 60.31±0.48 | 77.94±0.57 |
MTL | 61.20±1.80 | 75.50±0.80 |
MetaOptNet | 62.64±0.61 | 78.63±0.46 |
ConstellationNet | 64.89±0.23 | 79.95±0.17 |
Ours | 65.55±0.77 | 81.93±0.27 |
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.基于对比学习的小样本分类方法,其特征在于,所述方法包括:
步骤S1:构建特征提取网络模型;所述特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;所述两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,所述两个残差模块分别为第一残差模块和第二残差模块,所述两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块;
步骤S2:将基础数据集划分为三元样本对,将所述三元样本对输入到所述特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;所述三元样本对包括正样本、锚点样本和负样本;
步骤S3:计算所述锚点样本特征向量与所述正样本特征向量的余弦相似度,得到正锚相似度;计算所述锚点样本特征向量与所述负样本特征向量的余弦相似度,得到负锚相似度;
步骤S4:将所述正锚相似度和所述负锚相似度进行拼接,得到整体相似度,计算所述整体相似度与标签的损失,判断所述损失是否小于或等于第一阈值;如果所述损失大于所述第一阈值,则返回“步骤S2”;如果所述损失小于或等于所述第一阈值,则执行“步骤S5”;
步骤S5:计算支持集损失,并更新线性分类器参数。
2.根据权利要求1所述的基于对比学习的小样本分类方法,其特征在于,所述第一残差模块,具体包括:
所述第一残差模块包括第一残差输入层、第二标准卷积层、第二规范化激活层、第三标准卷积层、第三规范化激活层、第一元素相加层、第四标准卷积层、第四规范化激活层、第二元素相加层、第五标准卷积层、第五规范化激活层、第一张量拼接层、第六标准卷积层、第一批归一化层、第七标准卷积层、第二批归一化层、第三元素相加层和第一激活函数层;
将所述第一残差输入层的特征图F2依次输入到所述第二标准卷积层和所述第二规范化激活层进行卷积和激活操作,得到特征图F4;
将所述特征图F4依次输入到所述第三标准卷积层和所述第三规范化激活层进行卷积和激活操作,得到特征图F6;
将所述特征图F4和所述特征图F6输入到所述第一元素相加层进行元素相加操作,得到特征图F7;
将所述特征图F7依次输入到所述第四标准卷积层和所述第四规范化激活层进行卷积和激活操作,得到特征图F9;
将所述特征图F7和所述特征图F9输入到所述第二元素相加层进行元素相加操作,得到特征图F10;
将所述特征图F10依次输入到所述第五标准卷积层和所述第五规范化激活层进行卷积和激活操作,得到特征图F12;
将所述特征图F4、所述特征图F6、所述特征图F9和所述特征图F12输入到所述第一张量拼接层进行张量拼接操作,得到特征图F13;
将所述特征图F13依次输入到所述第六标准卷积层和所述第一批归一化层进行卷积和归一化操作,得到特征图F15;
将所述第一残差输入层的所述特征图F2依次输入到所述第七标准卷积层和所述第二批归一化层进行卷积和归一化操作,得到特征图F17;
将所述特征图F15和所述特征图F17输入到所述第三元素相加层进行元素相加操作,得到特征图F18;
将所述特征图F18输入到所述第一激活函数层进行激活操作,得到特征图F19。
3.根据权利要求1所述的基于对比学习的小样本分类方法,其特征在于,所述第一分离注意力模块,具体包括:
所述第一分离注意力模块包括第一分离注意力输入层、第一全局平均池化层、第一全连接激活层、第二全连接激活层、第一维度扩展层、第二维度扩展层、第一元素相乘层、第一最大池化层、第一平均池化层、第二张量拼接层、第八标准卷积层、第二激活函数层、第二元素相乘层、第一深度可分离卷积层和第六规范化激活层;
将所述第一分离注意力输入层的特征图F19输入到所述第一全局平均池化层进行全局平均池化操作,得到特征图F20;
将所述特征图F20依次输入到所述第一全连接激活层和所述第二全连接激活层进行全连接激活操作,得到特征图F24;
将所述特征图F24依次输入到所述第一维度扩展层和所述第二维度扩展层进行维度扩展操作,得到特征图F26;
将所述第一分离注意力输入层的所述特征图F19和所述特征图F26输入到所述第一元素相乘层进行元素相乘操作,得到特征图F27;
将所述第一分离注意力输入层的所述特征图F19输入到所述第一最大池化层进行最大池化操作,得到特征图F28;
将所述第一分离注意力输入层的所述特征图F19输入到所述第一平均池化层进行平均池化操作,得到特征图F29;
将所述特征图F28和所述特征图F29输入到所述第二张量拼接层进行张量拼接操作,得到特征图F30;
将所述特征图F30依次输入到所述第八标准卷积层和所述第二激活函数层进行卷积和激活操作,得到特征图F32;
将所述特征图F27和所述特征图F32输入到所述第二元素相乘层进行元素相乘操作,得到特征图F33;
将所述特征图F33依次输入到所述第一深度可分离卷积层和所述第六规范化激活层进行卷积核激活操作,得到特征图F35。
4.根据权利要求1所述的基于对比学习的小样本分类方法,其特征在于,所述计算所述锚点样本特征向量与所述正样本特征向量的余弦相似度,得到正锚相似度,计算所述锚点样本特征向量与所述负样本特征向量的余弦相似度,得到负锚相似度,具体公式为:
式中,cos(ai,pi)为正锚相似度,cos(ai,ni)为负锚相似度,ai为锚点样本特征向量,pi为正样本特征向量,ni为负样本特征向量,|| ||为特征向量范数,i为第i个样本对,i∈[1,M],M为三元样本对总数。
5.根据权利要求1所述的基于对比学习的小样本分类方法,其特征在于,所述计算支持集损失,并更新线性分类器参数,具体公式为:
Pj=softmax((W·f(Xj)+B)
式中,Xj为支持集样本图片,Yj为支持集样本图片对应标签,W为分类器权重,f(Xj)为支持集样本图片经过所述特征提取网络模型得到的特征向量,B为偏置量,Pj为预测标签值,Regularization为熵正则化项,crossEntropy为交叉熵损失函数。
6.基于对比学习的小样本分类系统,其特征在于,所述系统包括:
网络模型构建模块,用于构建特征提取网络模型;所述特征提取网络模型包括两个标准卷积模块、两个残差模块、两个分离注意力模块和一个全局平均池化层模块;所述两个标准卷积模块分别为第一标准卷积模块和第二标准卷积模块,所述两个残差模块分别为第一残差模块和第二残差模块,所述两个分离注意力模块分别为第一分离注意力模块和第二分离注意力模块;
特征向量提取模块,用于将基础数据集划分为三元样本对,将所述三元样本对输入到所述特征提取网络模型,提取正样本特征向量,负样本特征向量,锚点样本特征向量;所述三元样本对包括正样本、锚点样本和负样本;
余弦相似度计算模块,用于计算所述锚点样本特征向量与所述正样本特征向量的余弦相似度,得到正锚相似度;计算所述锚点样本特征向量与所述负样本特征向量的余弦相似度,得到负锚相似度;
训练损失计算模块,用于将所述正锚相似度和所述负锚相似度进行拼接,得到整体相似度,计算所述整体相似度与标签的损失,判断所述损失是否小于或等于第一阈值;如果所述损失大于所述第一阈值,则返回“特征向量提取模块”;如果所述损失小于或等于所述第一阈值,则执行“损失参数更新模块”;
损失参数更新模块,用于计算支持集损失,并更新线性分类器参数。
7.根据权利要求6所述的基于对比学习的小样本分类系统,其特征在于,所述第一残差模块,具体包括:
第二标准卷积子模块,用于将第一残差输入层的特征图F2依次输入到第二标准卷积层和第二规范化激活层进行卷积和激活操作,得到特征图F4;
第三标准卷积子模块,用于将所述特征图F4依次输入到第三标准卷积层和第三规范化激活层进行卷积和激活操作,得到特征图F6;
第一元素相加子模块,用于将所述特征图F4和所述特征图F6输入到第一元素相加层进行元素相加操作,得到特征图F7;
第四标准卷积子模块,用于将所述特征图F7依次输入到第四标准卷积层和第四规范化激活层进行卷积和激活操作,得到特征图F9;
第二元素相加子模块,用于将所述特征图F7和所述特征图F9输入到第二元素相加层进行元素相加操作,得到特征图F10;
第五标准卷积子模块,用于将所述特征图F10依次输入到第五标准卷积层和第五规范化激活层进行卷积和激活操作,得到特征图F12;
第一张量拼接子模块,用于将所述特征图F4、所述特征图F6、所述特征图F9和所述特征图F12输入到第一张量拼接层进行张量拼接操作,得到特征图F13;
第六标准卷积子模块,用于将所述特征图F13依次输入到第六标准卷积层和第一批归一化层进行卷积和归一化操作,得到特征图F15;
第七标准卷积子模块,用于将所述第一残差输入层的所述特征图F2依次输入到第七标准卷积层和第二批归一化层进行卷积和归一化操作,得到特征图F17;
第三元素相加子模块,用于将所述特征图F15和所述特征图F17输入到第三元素相加层进行元素相加操作,得到特征图F18;
第一激活函数子模块,用于将所述特征图F18输入到第一激活函数层进行激活操作,得到特征图F19。
8.根据权利要求6所述的基于对比学习的小样本分类系统,其特征在于,所述第一分离注意力模块,具体包括:
将第一分离注意力输入层的特征图F19输入到第一全局平均池化层进行全局平均池化操作,得到特征图F20;
第一二全连接激活子模块,用于将所述特征图F20依次输入到第一全连接激活层和第二全连接激活层进行全连接激活操作,得到特征图F24;
第一二维度扩展子模块,用于将所述特征图F24依次输入到第一维度扩展层和第二维度扩展层进行维度扩展操作,得到特征图F26;
第一元素相乘子模块,用于将所述第一分离注意力输入层的所述特征图F19和所述特征图F26输入到第一元素相乘层进行元素相乘操作,得到特征图F27;
第一最大池化子模块,用于将所述第一分离注意力输入层的所述特征图F19输入到所述第一最大池化层进行最大池化操作,得到特征图F28;
第一平均池化子模块,用于将所述第一分离注意力输入层的所述特征图F19输入到所述第一平均池化层进行平均池化操作,得到特征图F29;
第二张量拼接子模块,用于将所述特征图F28和所述特征图F29输入到第二张量拼接层进行张量拼接操作,得到特征图F30;
第八标准卷积子模块,用于将所述特征图F30依次输入到第八标准卷积层和第二激活函数层进行卷积和激活操作,得到特征图F32;
第二元素相乘子模块,用于将所述特征图F27和所述特征图F32输入到第二元素相乘层进行元素相乘操作,得到特征图F33;
第一深度可分离子模块,用于将所述特征图F33依次输入到第一深度可分离卷积层和第六规范化激活层进行卷积核激活操作,得到特征图F35。
9.根据权利要求6所述的基于对比学习的小样本分类系统,其特征在于,所述余弦相似度计算模块,具体公式为:
式中,cos(ai,pi)为正锚相似度,cos(ai,ni)为负锚相似度,ai为锚点样本特征向量,pi为正样本特征向量,ni为负样本特征向量,|| ||为特征向量范数,i为第i个样本对,i∈[1,M],M为三元样本对总数。
10.根据权利要求6所述的基于对比学习的小样本分类系统,其特征在于,所述损失参数更新模块,具体公式为:
Pj=softmax(W·f(Xj)+B)
式中,Xj为支持集样本图片,Yj为支持集样本图片对应标签,W为分类器权重,f(Xj)为支持集样本图片经过所述特征提取网络模型得到的特征向量,B为偏置量,Pj为预测标签值,Regularization为熵正则化项,crossEntropy为交叉熵损失函数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311462624.0A CN117496243B (zh) | 2023-11-06 | 2023-11-06 | 基于对比学习的小样本分类方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311462624.0A CN117496243B (zh) | 2023-11-06 | 2023-11-06 | 基于对比学习的小样本分类方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117496243A true CN117496243A (zh) | 2024-02-02 |
CN117496243B CN117496243B (zh) | 2024-05-31 |
Family
ID=89680975
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311462624.0A Active CN117496243B (zh) | 2023-11-06 | 2023-11-06 | 基于对比学习的小样本分类方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117496243B (zh) |
Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109961089A (zh) * | 2019-02-26 | 2019-07-02 | 中山大学 | 基于度量学习和元学习的小样本和零样本图像分类方法 |
CN114359674A (zh) * | 2022-01-12 | 2022-04-15 | 浙江大学 | 基于度量学习的非侵入式负荷识别方法 |
CN115687894A (zh) * | 2022-10-27 | 2023-02-03 | 天津大学 | 一种基于小样本学习的跌倒检测系统及方法 |
CN115936961A (zh) * | 2022-11-21 | 2023-04-07 | 南京信息工程大学 | 基于少样本对比学习网络的隐写分析方法、设备及介质 |
CN116127298A (zh) * | 2023-02-22 | 2023-05-16 | 北京邮电大学 | 基于三元组损失的小样本射频指纹识别方法 |
CN116385950A (zh) * | 2023-04-07 | 2023-07-04 | 广西科学院 | 一种小样本条件下电力线路隐患目标检测方法 |
CN116504317A (zh) * | 2023-04-27 | 2023-07-28 | 平安科技(深圳)有限公司 | 疾病与基因关系预测方法、装置、电子设备及介质 |
CN116523840A (zh) * | 2023-03-30 | 2023-08-01 | 苏州大学 | 一种基于深度学习的肺部ct图像检测系统以及方法 |
US20230281972A1 (en) * | 2022-05-13 | 2023-09-07 | Nanjing University Of Aeronautics And Astronautics | Few-shot defect detection method based on metric learning |
WO2023185243A1 (zh) * | 2022-03-29 | 2023-10-05 | 河南工业大学 | 基于注意力调制上下文空间信息的表情识别方法 |
CN116977747A (zh) * | 2023-08-28 | 2023-10-31 | 中国地质大学(北京) | 基于多路多尺度特征孪生网络的小样本高光谱分类方法 |
-
2023
- 2023-11-06 CN CN202311462624.0A patent/CN117496243B/zh active Active
Patent Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109961089A (zh) * | 2019-02-26 | 2019-07-02 | 中山大学 | 基于度量学习和元学习的小样本和零样本图像分类方法 |
CN114359674A (zh) * | 2022-01-12 | 2022-04-15 | 浙江大学 | 基于度量学习的非侵入式负荷识别方法 |
WO2023185243A1 (zh) * | 2022-03-29 | 2023-10-05 | 河南工业大学 | 基于注意力调制上下文空间信息的表情识别方法 |
US20230281972A1 (en) * | 2022-05-13 | 2023-09-07 | Nanjing University Of Aeronautics And Astronautics | Few-shot defect detection method based on metric learning |
CN115687894A (zh) * | 2022-10-27 | 2023-02-03 | 天津大学 | 一种基于小样本学习的跌倒检测系统及方法 |
CN115936961A (zh) * | 2022-11-21 | 2023-04-07 | 南京信息工程大学 | 基于少样本对比学习网络的隐写分析方法、设备及介质 |
CN116127298A (zh) * | 2023-02-22 | 2023-05-16 | 北京邮电大学 | 基于三元组损失的小样本射频指纹识别方法 |
CN116523840A (zh) * | 2023-03-30 | 2023-08-01 | 苏州大学 | 一种基于深度学习的肺部ct图像检测系统以及方法 |
CN116385950A (zh) * | 2023-04-07 | 2023-07-04 | 广西科学院 | 一种小样本条件下电力线路隐患目标检测方法 |
CN116504317A (zh) * | 2023-04-27 | 2023-07-28 | 平安科技(深圳)有限公司 | 疾病与基因关系预测方法、装置、电子设备及介质 |
CN116977747A (zh) * | 2023-08-28 | 2023-10-31 | 中国地质大学(北京) | 基于多路多尺度特征孪生网络的小样本高光谱分类方法 |
Non-Patent Citations (2)
Title |
---|
FANG ZHAO ET AL.: "Dynamic Conditional Networksfor Few-Shot Learning", COMPUTER VISION, 7 October 2018 (2018-10-07) * |
PEIPEI XIA ET AL.: "Learning similarity with cosine similarity ensemble", INFORMATION SCIENCES, vol. 307, 31 December 2015 (2015-12-31) * |
Also Published As
Publication number | Publication date |
---|---|
CN117496243B (zh) | 2024-05-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110516095B (zh) | 基于语义迁移的弱监督深度哈希社交图像检索方法和系统 | |
CN112100346B (zh) | 基于细粒度图像特征和外部知识相融合的视觉问答方法 | |
CN111368909B (zh) | 一种基于卷积神经网络深度特征的车标识别方法 | |
CN112036447B (zh) | 零样本目标检测系统及可学习语义和固定语义融合方法 | |
CN111460894B (zh) | 一种基于卷积神经网络的车标智能检测方法 | |
CN113222068B (zh) | 基于邻接矩阵指导标签嵌入的遥感图像多标签分类方法 | |
Nguyen et al. | Satellite image classification using convolutional learning | |
Chen et al. | Dictionary learning from ambiguously labeled data | |
CN113378938B (zh) | 一种基于边Transformer图神经网络的小样本图像分类方法及系统 | |
CN113159023A (zh) | 基于显式监督注意力机制的场景文本识别方法 | |
CN114067385A (zh) | 基于度量学习的跨模态人脸检索哈希方法 | |
CN113469186A (zh) | 一种基于少量点标注的跨域迁移图像分割方法 | |
López-Cifuentes et al. | Attention-based knowledge distillation in scene recognition: the impact of a dct-driven loss | |
CN117496243B (zh) | 基于对比学习的小样本分类方法及系统 | |
CN117011515A (zh) | 基于注意力机制的交互式图像分割模型及其分割方法 | |
CN111709442A (zh) | 一种面向图像分类任务的多层字典学习方法 | |
CN115423090A (zh) | 一种面向细粒度识别的类增量学习方法 | |
CN115100694A (zh) | 一种基于自监督神经网络的指纹快速检索方法 | |
CN113449751B (zh) | 基于对称性和群论的物体-属性组合图像识别方法 | |
You et al. | A new multiple max-pooling integration module and cross multiscale deconvolution network based on image semantic segmentation | |
CN115424275A (zh) | 一种基于深度学习技术的渔船船牌号识别方法及系统 | |
CN115082762A (zh) | 基于区域建议网络中心对齐的目标检测无监督域适应系统 | |
CN115512174A (zh) | 应用二次IoU损失函数的无锚框目标检测方法 | |
CN111931788A (zh) | 基于复值的图像特征提取方法 | |
Xu et al. | Dictionary learning with mutually reinforcing group-graph structures |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |