CN114943859B - 面向小样本图像分类的任务相关度量学习方法及装置 - Google Patents

面向小样本图像分类的任务相关度量学习方法及装置 Download PDF

Info

Publication number
CN114943859B
CN114943859B CN202210479781.1A CN202210479781A CN114943859B CN 114943859 B CN114943859 B CN 114943859B CN 202210479781 A CN202210479781 A CN 202210479781A CN 114943859 B CN114943859 B CN 114943859B
Authority
CN
China
Prior art keywords
module
task related
model
task
formula
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210479781.1A
Other languages
English (en)
Other versions
CN114943859A (zh
Inventor
李晓旭
杨世丞
刘俊
燕锦涛
安文娟
张文斌
李睿凡
马占宇
陶剑
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Lanzhou University of Technology
Original Assignee
Lanzhou University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Lanzhou University of Technology filed Critical Lanzhou University of Technology
Priority to CN202210479781.1A priority Critical patent/CN114943859B/zh
Publication of CN114943859A publication Critical patent/CN114943859A/zh
Application granted granted Critical
Publication of CN114943859B publication Critical patent/CN114943859B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种面向小样本图像分类的任务相关度量学习方法及装置,方法主要由数据预处理阶段、构建网络模型阶段、训练模型参数阶段和测试模型性能阶段组成,本发明通过考虑不同任务之间的差异性,引入注意力机制的思想,并学习任务相关的空间映射,利用任务自适应度量学习的方式,解决了小样本图像分类中存在的自适应度量学习问题,从而提高在小样本条件下目标任务分类的准确性,改善了图像的分类效果,具有很高的实用价值。

Description

面向小样本图像分类的任务相关度量学习方法及装置
技术领域
本发明涉及计算机视觉领域中图像分类,尤其涉及一种面向小样本图像类内共性特征的任务相关度量学习方法及装置。
背景技术
近年来,随着计算机技术的发展,人们浏览的信息日益丰富,每天都有大量图片被上传到网络,由于数量巨大,人工已经无法对此进行分类。在很多大样本图像分类任务上,机器的识别性能已经超越人类。然而,当样本量比较少时,机器的识别水平仍与人类存在较大差距。因此,研究高效可靠的图片分类算法有很迫切的社会需求。
人类具体通过极少量样本识别一个新物体的能力,例如小朋友只需要看过书中的个别图片,就可以准确的判断什么是“香蕉”或者是“草莓”。小样本学习指的是研究人员希望机器学习模型在学习一定类别的大量数据后,遇到新的类别后,只需要少量的数据就可以快速的学习,实现“小样本学习”。
小样本分类属于小样本学习范畴,往往包含类别空间不相交的两类数据,即基类数据和新类数据。小样本分类旨在利用基类数据学习的知识和新类数据的少量标记样本(支持样本)来学习分类规则,准确预测新类任务中未标记样本(查询样本)的类别。
在小样本图像分类的研究方法中,基于深度度量的方法简单而且高效,主要通过比较样本间或者样本与类原型间的距离来判断类别。常常结合数据增强、迁移学习等技术来弥补数据量不足以及模型容易过拟合的缺陷,在很多小样本分类任务上获得了较好的分类性能。但与大样图像分类相比,现有小样本图像分类的性能仍不尽人意,很大程度上限制了小样本图像分类技术的实用化,在自适应的度量学习中还面临着以下问题亟待解决:
现有小样本分类方法中,大多假设小样本分类任务使用一个单一的度量方式,例如余弦距离、欧氏距离或一个可学习的度量网络模块。不同的任务包含不同的类别,有些任务适用余弦距离,有些任务适用欧氏距离。因此,如何构建任务自适应的度量也是小样本图像分类值得研究的问题。
发明内容
本发明针对上述技术问题,提出一种面向小样本图像分类的任务相关度量学习方法及装置,引入了注意力机制的思想,利用任务自适应度量学习的方式,通过考虑不同任务之间的差异性,并学习任务相关的空间映射,解决了小样本图像分类中存在任务自适应的度量问题,对于图像的分类效果十分明显,具有很高的实用价值。
为了实现上述目的,本发明提供如下技术方案:
一方面,本发明提供了一种面向小样本图像分类的任务相关度量学习方法,包括以下步骤:
S1、对数据进行预处理,其中数据包括训练集Dtrain和测试集Dtest,训练集Dtrain和测试集Dtest的类别空间互斥;
S2、构建面向小样本图像分类的任务相关度量学习模型,模型由嵌入模块
Figure BDA0003627103170000021
和任务相关度量模块组成;其中,嵌入模块包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;任务相关度量模块由注意力模块和余弦度量模块组成;
S3、将训练集数据送入面向小样本图像分类的任务相关度量学习模型进行训练,求解模型参数;
S4、利用训练后的面向小样本图像分类的任务相关度量学习模型对新类任务进行预测,测评模型的性能。
进一步地,步骤S1的预处理方法为:从训练集Dtrain中随机选出C个类别,每个类别中随机选出M个样本,其中K个样本作为支持样本Si,其余M-K个样本作为查询样本Qi,Si和Qi构成一个任务Ti,同样对于测试集Dtest也有任务Tk
进一步地,步骤S2中,每个卷积块包含一个带有64个滤波器的3×3的卷积,一个批量归一化,一个relu非线性层,一个2×2最大池化层,裁剪了最后两个块的最大池化层,全连接层共128维。
进一步地,步骤S3具体包括:
S301、对于Dtrain中的一个任务Ti,首先将所有支持样本和查询样本输入嵌入模块
Figure BDA0003627103170000037
中;
S302、利用嵌入模块中的卷积神经网络,将支持样本依次经过卷积层、池化层和激活层,最终提取图像的特征
Figure BDA0003627103170000038
S303、将支持样本特征Fs∈RHW×C分别作为V和K输入到任务相关度量模块中;
S304、将查询样本中特征Fq∈RHW×C,将其作为Q输入到任务相关度量模块中,其中H和W代表特征空间的大小,C代表特征的通道数;
S305、将V,K,Q分别经过三个权重不同的线性层
Figure BDA0003627103170000039
将提取出来的特征投影到低维,得到转换后特征,表示为/>
Figure BDA0003627103170000031
公式如下:
Figure BDA0003627103170000032
在公式(1)中,Fs代表支持样本特征,Fq代表查询样本特征,Wv,Wk,Wq代表三个权重不同的线性层,
Figure BDA0003627103170000033
分别代表Fs经过Wv、Fs经过Wk、Fq经过Wq所得到的转化后的特征;公式(1)表示将V,K,Q经过Wv,Wk,Wq三个权重不同的线性层投影到低维;
S306、利用公式(2)计算所有支持样本的预测概率,公式如下:
Figure BDA0003627103170000034
在公式(2)中,
Figure BDA0003627103170000035
代表矩阵的对应元素相乘,/>
Figure BDA0003627103170000036
代表经过公式(1)转化后的特征,C代表特征的通道数,softmax代表softmax激活函数,FA代表经过公式(2)后得到的加权特征;公式(2)表示求得加权注意力权重后的特征;
S307、将Fa再经过一个线性层后得到任务自适应的支持样本特征FA∈RHW×C
S308、将查询样本特征Q和任务自适应的支持样本特征FA共同输入到余弦度量模块中,度量模块采用余弦分类器,用于查询样本的分类;
S309、使用交叉熵损失函数计算支持样本与查询样本的分类预测损失l0,将l0作为整个网络的总损失loss;
S310、根据求得的loss使用mini-batch和Adam优化器更新嵌入模块
Figure BDA0003627103170000044
和任务相关度量模型的可学习参数,重复训练多个任务,直到网络收敛。
进一步地,步骤S308中余弦度量模块运算公式如下:
Figure BDA0003627103170000041
在公式(3)中,Fq代表查询样本特征,FA代表任务自适应的支持样本特征,
Figure BDA0003627103170000042
代表矩阵的对应元素相乘,||A||表示求矩阵A的二范数,F代表求得的余弦相似度矩阵,公式(3)表示求出任务自适应的支持样本特征FA和查询样本特征Fq之间的余弦相似度矩阵。
进一步地,步骤S309中的交叉熵损失函数公式如下:
Figure BDA0003627103170000043
在公式(4)中,n代表种类数量,y代表类标签,若类别是i,则yi=1且其他位为0,pi代表类别是i的概率,其值为公式(3)算出来的F矩阵中的对应位置的元素值,loss代表计算后得到的当前损失值,公式(4)表示根据交叉熵损失函数,计算当前的网络模型情况下的损失值。
进一步地,S310中使用的Adam学习率自适应优化算法具体步骤如下:
S3101、对数据进行初始化:
vdW=0,SdW=0,vdb=0,Sdb=0
W代表W1,W2,…,Wn的集合,b代表b1,b2,…,bn的集合,在第t次迭代中,用当前的mini-batch计算W和b的微分dW,db;
S3102、Momentum算法:
根据公式(5)和(6)计算梯度微分的指数加权平均数:
vdW=β1vdW+(1-β1)dW (5)
vdb=β1vdb+(1-β1)db (6)
S3103、RMSprop算法:
根据公式(7)和(8)计算梯度微分平方的指数加权平均数:
SdW=β2SdW+(1-β2)(dW)2 (7)
Sdb=β2Sdb+(1-β2)(db)2 (8)
S3104、对两种算法进行偏差修正:
根据公式(9)和(10)进行Momentum算法偏差修正:
Figure BDA0003627103170000051
Figure BDA0003627103170000052
根据公式(11)和(12)进行RMSprop算法偏差修正;
Figure BDA0003627103170000053
Figure BDA0003627103170000054
S3105、根据公式(13)和(14)进行梯度下降,更新参数:
Figure BDA0003627103170000055
Figure BDA0003627103170000056
在公式(5)-(14)中,vdW,vdb,SdW,Sdb分别代表有偏差的一阶和二阶矩估计,t代表次数,α代表学习率,ε代表用于数值稳定的小常数,β12代表矩估计的指数衰减率,dW,db分别代表W和b的微分,
Figure BDA0003627103170000057
代表经过偏差修正后的一阶和二阶矩估计。
进一步地,步骤S4的具体步骤为:
S401、将新类的任务输入训练好的嵌入模块
Figure BDA0003627103170000058
中;
S402、将嵌入模块输出的矩阵特征经过任务相关度量模块,得到查询样本与支持样本各类别的余弦度量;
S403、将相似度最高的类作为预测标签,根据预测结果评估模型性能。
另一方面,本发明还提供了一种面向小样本图像分类的任务相关度量学习装置,用以实现上述的任一项方法,包括以下模块:
数据预处理模块:用于对数据进行预处理,将数据划分成为训练集和测试集并且确定模型的训练方式;
网络模型构建模块:用于引入注意力机制和自适应度量学习,构建面向小样本图像分类的任务相关度量学习模型,模型由嵌入模块
Figure BDA0003627103170000061
和任务相关度量模块组成;其中,嵌入模块包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;任务相关度量模块由注意力模块和余弦度量模块组成;
训练模型参数模块:用于面向小样本图像分类的任务相关度量学习模型进行训练,求解模型参数;
测试模型性能模块:利用训练好的面向小样本图像分类的任务相关度量学习模型对新类的任务进行预测,测评模型的性能。
与现有技术相比,本发明的有益效果为:
本发明引入注意力机制的思想,建立了一种面向小样本图像分类的任务相关度量学习方法及装置,通过考虑不同任务之间的差异性,并学习任务相关的空间映射,利用任务自适应度量学习的方式,解决了小样本图像分类中存在的自适应度量学习问题,从而提高在小样本条件下目标任务分类的准确性,改善了图像的分类效果,具有很高的实用价值。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的面向小样本图像分类的任务相关度量学习方法的阶段流程图。
图2为本发明实施例提供的面向小样本图像分类的任务相关度量学习模型结构图。
图3为本发明实施例提供的面向小样本图像分类的任务相关度量学习模型的功能模块构成图。
图4为本发明实施例提供的嵌入模块
Figure BDA0003627103170000062
结构图。
图5为本发明实施例提供的特征矩阵经过三个权重不同的线性层示意图。
图6为本发明实施例提供的注意力机制运算示意图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。本发明中的实施例,本领域技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
根据本文公开的一个方面,提供了一种面向小样本图像分类的任务相关度量学习方法,如图1所示,包括以下阶段步骤:
S1、数据预处理阶段:对数据进行预处理,其中数据包括训练集和测试集;
S2、构建网络模型阶段:引入注意力机制和自适应度量学习,构建面向小样本图像分类的任务相关度量学习模型;
S3、训练模型参数阶段:将训练集数据送入面向小样本图像分类的任务相关度量学习模型进行训练,求解模型参数;
S4、测试模型性能阶段:利用训练后的面向小样本图像分类的任务相关度量学习模型对新类任务进行预测,测评模型的性能。
在一些实施例中,所述阶段步骤S1包括以下子步骤:
S101、将数据
Figure BDA0003627103170000071
分为/>
Figure BDA0003627103170000072
Figure BDA0003627103170000073
两个部分,且这两个部分的类别空间互斥。将Dtrain作为基类数据训练模型,Dtest作为新类数据测评模型性能;
S102、对于C-way K-shot分类任务,从Dtrain中随机选出C个类别,每个类别中随机选出M个样本,其中K个样本作为支持样本Si,其余M-K个样本作为查询样本Qi,Si和Qi构成一个任务Ti,同样对于Dtest也有任务Tk
在一些实施例中,所述步骤S2包括以下步骤:
构建面向小样本图像分类的任务相关度量学习模型,其结构如图2所示,网络模型划分为嵌入式模块和任务相关度量模块,如图3所示。嵌入式模块由输入层、卷积层、池化层以及激活函数组成,如图4所示,目的是为了提取样本的局部特征;其中,在遵循四层卷积架构来形成特征提取器
Figure BDA0003627103170000086
每个块包含一个带有64个滤波器的3×3的卷积,一个批量归一化,一个relu非线性层,一个2×2最大池化层,裁剪了最后两个块的最大池化层,全连接层共128维。
任务相关度量模块由注意力模块和余弦度量模块组成,目的是通过考虑不同任务之间的差异性,并学习任务相关的空间映射,从而提高在小样本条件下目标任务分类的准确性。
在一些实施例中,所述步骤S3包括以下子步骤:
S301、对于Dtrain中的一个任务Ti,首先将所有支持样本和查询样本输入嵌入模块
Figure BDA0003627103170000087
中;
S302、利用嵌入模块中的卷积神经网络,将支持样本依次经过卷积层、池化层和激活层,最终提取图像的特征
Figure BDA0003627103170000088
S303、将支持样本特征Fs∈RHW×C分别作为V和K输入到任务相关度量模块中;
S304、将查询样本中特征Fq∈RHW×C,将其作为Q输入到任务相关度量模块中,其中H和W代表特征空间的大小,C代表特征的通道数;
S305、将V,K,Q分别经过三个权重不同的线性层
Figure BDA0003627103170000089
将提取出来的特征投影到低维,得到转换后特征,表示为/>
Figure BDA00036271031700000810
公式如下:
Figure BDA0003627103170000081
在公式(1)中,Fs代表支持样本特征,Fq代表查询样本特征,Wv,Wk,Wq代表三个权重不同的线性层,
Figure BDA0003627103170000082
分别代表Fs经过Wv、Fs经过Wk、Fq经过Wq所得到的转化后的特征;公式(1)表示将V,K,Q经过Wv,Wk,Wq三个权重不同的线性层投影到低维,如图5。
S306、利用公式(2)计算所有支持样本的预测概率,公式如下:
Figure BDA0003627103170000083
在公式(2)中,
Figure BDA0003627103170000084
代表矩阵的对应元素相乘,/>
Figure BDA0003627103170000085
代表经过公式(1)转化后的特征,C代表特征的通道数,softmax代表softmax激活函数,FA代表经过公式(2)后得到的加权特征,公式(2)表示求得加权注意力权重后的特征,如图6所示。
在公式(2)中,对经过运算后的
Figure BDA0003627103170000091
矩阵使用softmax激活函数,/>
Figure BDA0003627103170000092
表示两个特征矩阵的对应元素相乘,即是将Fs k转置后与Fq q进行矩阵相乘,softmax激活函数分别对每一行、每一列进行归一化得到注意力权重矩阵α∈RHW×HW,使用这种方法可以保持梯度的稳定。
S307、将Fa再经过一个线性层后得到任务自适应的支持样本特征FA∈RHW×C
S308、接着将查询样本特征Q和任务自适应的支持样本特征FA共同输入到余弦度量模块中,度量模块采用余弦分类器,用于查询样本的分类。
采用的余弦度量模块,其运算公式如下:
Figure BDA0003627103170000093
在公式(3)中,Fq代表查询样本特征,FA代表任务自适应的支持样本特征,
Figure BDA0003627103170000094
代表矩阵的对应元素相乘,||A||表示求矩阵A的二范数,F代表求得的余弦相似度矩阵。公式(3)表示求出任务自适应的支持样本特征FA和查询样本特征Fq之间的余弦相似度矩阵。
S309、使用交叉熵损失函数计算支持样本与查询样本的分类预测损失l0,将l0作为整个网络的总损失loss。
交叉熵损失函数公式如下:
Figure BDA0003627103170000095
在公式(4)中,n代表种类数量,y代表类标签,若类别是i,则yi=1且其他位为0,pi代表类别是i的概率,其值为公式(3)算出来的F矩阵中的对应位置的元素值,loss代表计算后得到的当前损失值。公式(4)表示根据交叉熵损失函数,计算当前的网络模型情况下的损失值。
S310、根据求得的loss使用mini-batch和Adam优化器更新嵌入模块
Figure BDA0003627103170000106
和任务相关度量模型的可学习参数,重复训练多个任务,直到网络收敛。
使用的Adam学习率自适应优化算法具体步骤如下:
S3101、为了简化描述,W代表W1,W2,…,Wn的集合,b代表b1,b2,…,bn的集合。
对数据进行初始化:vdW=0,SdW=0,vdb=0,Sdb=0。
在第t次迭代中,用当前的mini-batch计算W和b的微分dW,db。
S3102、Momentum算法
公式计算梯度微分的指数加权平均数:
vdW=β1vdW+(1-β1)dW (5)
vdb=β1vdb+(1-β1)db (6)
S3103、RMSprop算法公式计算梯度微分平方的指数加权平均数:
SdW=β2SdW+(1-β2)(dW)2 (7)
Sdb=β2Sdb+(1-β2)(db)2 (8)
S3104、对两种算法都进行偏差修正:
1)Momentum算法偏差修正:
Figure BDA0003627103170000101
Figure BDA0003627103170000102
2)RMSprop算法偏差修正;
Figure BDA0003627103170000103
Figure BDA0003627103170000104
S3105、进行梯度下降,更新参数:
Figure BDA0003627103170000105
Figure BDA0003627103170000111
在公式(5)-(14)中,vdW,vdb,SdW,Sdb分别代表有偏差的一阶和二阶矩估计,t代表次数,α代表学习率,ε代表用于数值稳定的小常数,β12代表矩估计的指数衰减率,dW,db分别代表W和b的微分,
Figure BDA0003627103170000112
代表经过偏差修正后的一阶和二阶矩估计。
Mini-batch算法优点:把数据分为若干个批,按批来更新参数,这样,一个批中的一组数据共同决定了本次梯度的方向,下降起来就不容易跑偏,减少了随机性。另一方面因为批的样本数与整个数据集相比小了很多,计算量也不是很大,更加便于计算。
Adam优化算法优点:动量直接并入了梯度一阶矩(指数加权)的估计,将动量应用于缩放后的梯度。包括了偏差修正步骤,修正从原点初始化的一阶矩(动量项)和(非中心的)二阶矩估计。
在一些实施例中,所述步骤S4包括以下子步骤:
S401、将新类的任务输入训练好的嵌入模块
Figure BDA0003627103170000113
中;
S402、将嵌入模块输出的矩阵特征经过任务相关度量模块,得到查询样本与支持样本各类别的余弦度量;
S403、将相似度最高的类作为预测标签,根据预测结果评估模型性能。
根据本文公开的另一个方面,本发明还提供了一种面向小样本图像分类的任务相关度量学习装置,用于实现上述面向小样本图像分类的任务相关度量学习方法,包括:
数据预处理模块:对数据进行预处理,将数据划分成为训练集和测试集并且确定模型的训练方式;
网络模型构建模块:引入注意力机制和自适应度量学习,构建面向小样本图像分类的任务相关度量学习模型;模型由嵌入模块
Figure BDA0003627103170000114
和任务相关度量模块组成;其中,嵌入模块包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;任务相关度量模块由注意力模块和余弦度量模块组成;
训练模型参数模块:用面向小样本图像分类的任务相关度量学习模型进行训练,求解模型参数;
测试模型性能模块:利用训练好的面向小样本图像分类的任务相关度量学习模型对新类的任务进行预测,测评模型的性能。
以上结合附图对所提出的面向小样本图像分类的任务相关度量学习方法及模型的具体实施方式进行了阐述。通过以上实施方式的描述,所属领域的技术人员可以清楚的了解该方法以及装置的实施。
在此提供的算法和显示不与任何特定计算机、虚拟系统或者其他设备固有相关。各种通用系统也可以与基于在此地启示一起使用。根据上面的描述,构造这类系统所要求的结构是显而易见的。此外,本文公开的也不针对任何特定的编程语言。但是应当了解,可以利用各种编程语言实现在此描述的本文公开的内容,并且上面对特定语言所做的描述是为了披露本文公开的最佳实施方式。
类似的,应当理解,为了使本文尽量精简并且帮助理解各个公开方面中的一个或多个,在上面对本文公开的示例性实施例的描述中,本文公开的各个特征有时被一起分组到单个实施例、图、或者对其的描述中。然而,并不应将该公开的方法解释成反映如下示意图:即要求所保护的本文公开的要求比在每个权力要求中所明确记载的特征具有更多的特征。更确切地说,如下面的权力要求书所反映的那样,公开方面在于少于前面公开的单个实施例的所有特征。因此,遵循具体实施方式的权利要求书由此明确地并入该具体实施方式,其中每个权利要求本身都作为本公开的单独实施例子。
以上所述实施例,仅为本申请的具体实施方式,用以说明本申请的技术方案,而非对其限制,本申请的保护范围并不局限于此,尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特殊进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本申请实施例技术方案的精神和范围。都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应所述以权利要求的保护范围为准。

Claims (6)

1.一种面向小样本图像分类的任务相关度量学习方法,其特征在于,包括以下步骤:
S1、对数据进行预处理,其中数据包括训练集Dtrain和测试集Dtest,训练集Dtrain和测试集Dtest的类别空间互斥;
S2、构建面向小样本图像分类的任务相关度量学习模型,模型由嵌入模块
Figure FDA0004201747750000016
和任务相关度量模块组成;其中,嵌入模块包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;任务相关度量模块由注意力模块和余弦度量模块组成;
S3、将训练集数据送入面向小样本图像分类的任务相关度量学习模型进行训练,求解模型参数;
步骤S3具体包括:
S301、对于Dtrain中的一个任务Ti,首先将所有支持样本和查询样本输入嵌入模块
Figure FDA0004201747750000017
中;
S302、利用嵌入模块中的卷积神经网络,将支持样本依次经过卷积层、池化层和激活层,最终提取图像的特征
Figure FDA0004201747750000015
S303、将支持样本特征Fs∈RHW×C分别作为V和K输入到任务相关度量模块中;
S304、将查询样本中特征Fq∈RHW×C,将其作为Q输入到任务相关度量模块中,其中H和W代表特征空间的大小,C代表特征的通道数;
S305、将V,K,Q分别经过三个权重不同的线性层
Figure FDA0004201747750000014
将提取出来的特征投影到低维,得到转换后特征,表示为/>
Figure FDA0004201747750000011
公式如下:
Figure FDA0004201747750000012
在公式(1)中,Fs代表支持样本特征,Fq代表查询样本特征,Wv,Wk,Wq代表三个权重不同的线性层,
Figure FDA0004201747750000013
分别代表Fs经过Wv、Fs经过Wk、Fq经过Wq所得到的转化后的特征;公式(1)表示将V,K,Q经过Wv,Wk,Wq三个权重不同的线性层投影到低维;
S306、利用公式(2)计算所有支持样本的预测概率,公式如下:
Figure FDA0004201747750000021
在公式(2)中,
Figure FDA0004201747750000026
代表矩阵的对应元素相乘,/>
Figure FDA0004201747750000022
代表经过公式(1)转化后的特征,C代表特征的通道数,softmax代表softmax激活函数,Fa代表经过公式(2)后得到的加权特征;公式(2)表示求得加权注意力权重后的特征;
S307、将Fa再经过一个线性层后得到任务自适应的支持样本特征FA∈RHW×C
S308、将查询样本特征Q和任务自适应的支持样本特征FA共同输入到余弦度量模块中,度量模块采用余弦分类器,用于查询样本的分类;余弦度量模块运算公式如下:
Figure FDA0004201747750000023
在公式(3)中,Fq代表查询样本特征,FA代表任务自适应的支持样本特征,
Figure FDA0004201747750000027
代表矩阵的对应元素相乘,||A||表示求矩阵A的二范数,F代表求得的余弦相似度矩阵,公式(3)表示求出任务自适应的支持样本特征FA和查询样本特征Fq之间的余弦相似度矩阵;
S309、使用交叉熵损失函数计算支持样本与查询样本的分类预测损失l0,将l0作为整个网络的总损失loss;交叉熵损失函数公式如下:
Figure FDA0004201747750000024
在公式(4)中,n代表种类数量,y代表类标签,若类别是i,则yi=1且其他位为0,pi代表类别是i的概率,其值为公式(3)算出来的F矩阵中的对应位置的元素值,loss代表计算后得到的当前损失值,公式(4)表示根据交叉熵损失函数,计算当前的网络模型情况下的损失值;
S310、根据求得的loss使用mini-batch和Adam优化器更新嵌入模块
Figure FDA0004201747750000025
和任务相关度量模型的可学习参数,重复训练多个任务,直到网络收敛;
S4、利用训练后的面向小样本图像分类的任务相关度量学习模型对新类任务进行预测,测评模型的性能。
2.根据权利要求1所述的面向小样本图像分类的任务相关度量学习方法,其特征在于,步骤S1的预处理方法为:从训练集Dtrain中随机选出C个类别,每个类别中随机选出M个样本,其中K个样本作为支持样本Si,其余M-K个样本作为查询样本Qi,Si和Qi构成一个任务Ti,同样对于测试集Dtest也有任务Tk
3.根据权利要求1所述的面向小样本图像分类的任务相关度量学习方法,其特征在于,步骤S2中,每个卷积块包含一个带有64个滤波器的3×3的卷积,一个批量归一化,一个relu非线性层,一个2×2最大池化层,裁剪了最后两个块的最大池化层,全连接层共128维。
4.根据权利要求1所述的面向小样本图像分类的任务相关度量学习方法,其特征在于,S310中使用的Adam学习率自适应优化算法具体步骤如下:
S3101、对数据进行初始化:
vdW=0,SdW=0,vdb=0,Sdb=0
W代表W1,W2,…,Wn的集合,b代表b1,b2,…,bn的集合,在第t次迭代中,用当前的mini-batch计算W和b的微分dW,db;
S3102、Momentum算法:
根据公式(5)和(6)计算梯度微分的指数加权平均数:
vdW=β1vdW+(1-β1)dW (5)
vdb=β1vdb+(1-β1)db (6)
S3103、RMSprop算法:
根据公式(7)和(8)计算梯度微分平方的指数加权平均数:
SdW=β2SdW+(1-β2)(dW)2 (7)
Sdb=β2Sdb+(1-β2)(db)2 (8)
S3104、对两种算法进行偏差修正:
根据公式(9)和(10)进行Momentum算法偏差修正:
Figure FDA0004201747750000031
Figure FDA0004201747750000032
根据公式(11)和(12)进行RMSprop算法偏差修正;
Figure FDA0004201747750000041
Figure FDA0004201747750000042
S3105、根据公式(13)和(14)进行梯度下降,更新参数:
Figure FDA0004201747750000043
Figure FDA0004201747750000044
在公式(5)-(14)中,vdW,vdb,SdW,Sdb分别代表有偏差的一阶和二阶矩估计,t代表次数,α代表学习率,ε代表用于数值稳定的小常数,β12代表矩估计的指数衰减率,dW,db分别代表W和b的微分,
Figure FDA0004201747750000045
代表经过偏差修正后的一阶和二阶矩估计。
5.根据权利要求1所述的面向小样本图像分类的任务相关度量学习方法,其特征在于,步骤S4的具体步骤为:
S401、将新类的任务输入训练好的嵌入模块
Figure FDA0004201747750000046
中;
S402、将嵌入模块输出的矩阵特征经过任务相关度量模块,得到查询样本与支持样本各类别的余弦度量;
S403、将相似度最高的类作为预测标签,根据预测结果评估模型性能。
6.一种面向小样本图像分类的任务相关度量学习装置,其特征在于,用以实现权利要求1-5任一项所述的面向小样本图像分类的任务相关度量学习方法,包括以下模块:
数据预处理模块:用于对数据进行预处理,将数据划分成为训练集和测试集并且确定模型的训练方式;
网络模型构建模块:用于引入注意力机制和自适应度量学习,构建面向小样本图像分类的任务相关度量学习模型,模型由嵌入模块
Figure FDA0004201747750000047
和任务相关度量模块组成;其中,嵌入模块包含四个卷积块,每个卷积块均包括卷积层、池化层以及非线性激活函数;任务相关度量模块由注意力模块和余弦度量模块组成;
训练模型参数模块:用于面向小样本图像分类的任务相关度量学习模型进行训练,求解模型参数;
测试模型性能模块:利用训练好的面向小样本图像分类的任务相关度量学习模型对新类的任务进行预测,测评模型的性能。
CN202210479781.1A 2022-05-05 2022-05-05 面向小样本图像分类的任务相关度量学习方法及装置 Active CN114943859B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210479781.1A CN114943859B (zh) 2022-05-05 2022-05-05 面向小样本图像分类的任务相关度量学习方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210479781.1A CN114943859B (zh) 2022-05-05 2022-05-05 面向小样本图像分类的任务相关度量学习方法及装置

Publications (2)

Publication Number Publication Date
CN114943859A CN114943859A (zh) 2022-08-26
CN114943859B true CN114943859B (zh) 2023-06-20

Family

ID=82906413

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210479781.1A Active CN114943859B (zh) 2022-05-05 2022-05-05 面向小样本图像分类的任务相关度量学习方法及装置

Country Status (1)

Country Link
CN (1) CN114943859B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116168257B (zh) * 2023-04-23 2023-07-04 安徽大学 基于样本生成的小样本图像分类方法、设备及存储介质
CN116612335B (zh) * 2023-07-18 2023-09-19 贵州大学 一种基于对比学习的少样本细粒度图像分类方法

Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109919183A (zh) * 2019-01-24 2019-06-21 北京大学 一种基于小样本的图像识别方法、装置、设备及存储介质
CN109961089A (zh) * 2019-02-26 2019-07-02 中山大学 基于度量学习和元学习的小样本和零样本图像分类方法
CN110020682A (zh) * 2019-03-29 2019-07-16 北京工商大学 一种基于小样本学习的注意力机制关系对比网络模型方法
CN112288013A (zh) * 2020-10-30 2021-01-29 中南大学 基于元度量学习的小样本遥感场景分类方法
CN112836773A (zh) * 2021-04-08 2021-05-25 河海大学 一种基于全局注意力残差网络的高光谱图像分类方法
CN113255701A (zh) * 2021-06-24 2021-08-13 军事科学院系统工程研究院网络信息研究所 一种基于绝对-相对学习架构的小样本学习方法和系统
CN113537305A (zh) * 2021-06-29 2021-10-22 复旦大学 一种基于匹配网络少样本学习的图像分类方法
CN113655479A (zh) * 2021-08-16 2021-11-16 西安电子科技大学 基于可变形卷积和双注意力的小样本sar目标分类方法
CN113723562A (zh) * 2021-09-10 2021-11-30 中国计量大学 一种基于小样本学习的胸部x光图像多种疾病分类方法
CN113963165A (zh) * 2021-09-18 2022-01-21 中国科学院信息工程研究所 一种基于自监督学习的小样本图像分类方法及系统
CN114067160A (zh) * 2021-11-22 2022-02-18 重庆邮电大学 基于嵌入平滑图神经网络的小样本遥感图像场景分类方法

Patent Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109919183A (zh) * 2019-01-24 2019-06-21 北京大学 一种基于小样本的图像识别方法、装置、设备及存储介质
CN109961089A (zh) * 2019-02-26 2019-07-02 中山大学 基于度量学习和元学习的小样本和零样本图像分类方法
CN110020682A (zh) * 2019-03-29 2019-07-16 北京工商大学 一种基于小样本学习的注意力机制关系对比网络模型方法
CN112288013A (zh) * 2020-10-30 2021-01-29 中南大学 基于元度量学习的小样本遥感场景分类方法
CN112836773A (zh) * 2021-04-08 2021-05-25 河海大学 一种基于全局注意力残差网络的高光谱图像分类方法
CN113255701A (zh) * 2021-06-24 2021-08-13 军事科学院系统工程研究院网络信息研究所 一种基于绝对-相对学习架构的小样本学习方法和系统
CN113537305A (zh) * 2021-06-29 2021-10-22 复旦大学 一种基于匹配网络少样本学习的图像分类方法
CN113655479A (zh) * 2021-08-16 2021-11-16 西安电子科技大学 基于可变形卷积和双注意力的小样本sar目标分类方法
CN113723562A (zh) * 2021-09-10 2021-11-30 中国计量大学 一种基于小样本学习的胸部x光图像多种疾病分类方法
CN113963165A (zh) * 2021-09-18 2022-01-21 中国科学院信息工程研究所 一种基于自监督学习的小样本图像分类方法及系统
CN114067160A (zh) * 2021-11-22 2022-02-18 重庆邮电大学 基于嵌入平滑图神经网络的小样本遥感图像场景分类方法

Also Published As

Publication number Publication date
CN114943859A (zh) 2022-08-26

Similar Documents

Publication Publication Date Title
CN114943859B (zh) 面向小样本图像分类的任务相关度量学习方法及装置
CN113128558B (zh) 基于浅层空间特征融合与自适应通道筛选的目标检测方法
CN109598220A (zh) 一种基于多元输入多尺度卷积的人数统计方法
CN114332578A (zh) 图像异常检测模型训练方法、图像异常检测方法和装置
CN113065013B (zh) 图像标注模型训练和图像标注方法、系统、设备及介质
CN116147130A (zh) 智能家居控制系统及其方法
CN112926485B (zh) 一种少样本水闸图像分类方法
CN114358197A (zh) 分类模型的训练方法及装置、电子设备、存储介质
CN115631396A (zh) 一种基于知识蒸馏的YOLOv5目标检测方法
CN114780866B (zh) 一种基于时空上下文兴趣学习模型的个性化智能推荐方法
CN117237733A (zh) 一种结合自监督和弱监督学习的乳腺癌全切片图像分类方法
Amendola et al. Data assimilation in the latent space of a neural network
CN115830596A (zh) 基于融合金字塔注意力的遥感图像语义分割方法
CN113627597A (zh) 一种基于通用扰动的对抗样本生成方法及系统
CN111783688B (zh) 一种基于卷积神经网络的遥感图像场景分类方法
CN112381148A (zh) 一种基于随机区域插值的半监督图像分类方法
CN112529057A (zh) 一种基于图卷积网络的图相似性计算方法及装置
CN114723998B (zh) 基于大边界贝叶斯原型学习的小样本图像分类方法及装置
CN116579468A (zh) 基于云系记忆的台风生成预测方法、装置、设备及介质
CN115294381B (zh) 基于特征迁移和正交先验的小样本图像分类方法及装置
CN116741273A (zh) 一种识别空间转录组空间区域和细胞类型的特征学习方法
CN114898136B (zh) 一种基于特征自适应的小样本图像分类方法
CN115953902A (zh) 一种基于多视图时空图卷积网络的交通流预测方法
CN114818945A (zh) 融入类别自适应度量学习的小样本图像分类方法及装置
CN115424012A (zh) 一种基于上下文信息的轻量图像语义分割方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant