CN114639000A - 一种基于跨样本注意力聚合的小样本学习方法和装置 - Google Patents

一种基于跨样本注意力聚合的小样本学习方法和装置 Download PDF

Info

Publication number
CN114639000A
CN114639000A CN202210331296.XA CN202210331296A CN114639000A CN 114639000 A CN114639000 A CN 114639000A CN 202210331296 A CN202210331296 A CN 202210331296A CN 114639000 A CN114639000 A CN 114639000A
Authority
CN
China
Prior art keywords
category
query
aggregation
picture
class
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210331296.XA
Other languages
English (en)
Inventor
曹广
刘鹏
周迪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Zhejiang Uniview Technologies Co Ltd
Original Assignee
Zhejiang University ZJU
Zhejiang Uniview Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU, Zhejiang Uniview Technologies Co Ltd filed Critical Zhejiang University ZJU
Priority to CN202210331296.XA priority Critical patent/CN114639000A/zh
Publication of CN114639000A publication Critical patent/CN114639000A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/55Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/58Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually
    • G06F16/583Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually using metadata automatically derived from the content
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Molecular Biology (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Library & Information Science (AREA)
  • Probability & Statistics with Applications (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于跨样本注意力聚合的小样本学习方法,包括将将支持集输入至卷积神经网络后压缩至矩阵形式得到多类别特征向量,将多类别特征向量输入至注意力聚合模块得到支持集的多个类别聚合向量;将查询集输入至卷积神经网络后压缩至矩阵形式得到查询特征向量,将查询特征向量输入至注意力聚合模块得到查询聚合向量;基于类别聚合向量、查询聚合向量和真实类别聚合向量通过距离感知概率激活方法得到查询图片的类别预测概率值;类别预测概率值构建交叉熵损失函数,通过交叉熵损失函数优化训练模型中的参数得到类别预测模型;该方法能够清晰区分物体和背景区域,并能够准确提取类别特征。

Description

一种基于跨样本注意力聚合的小样本学习方法和装置
技术领域
本发明属于小样本学习领域,具体涉及一种基于跨样本注意力聚合的小样本学习方法和装置。
背景技术
人类可以从很少的实例中学习新的概念,并拥有很强的泛化能力,这些能力是深度学习算法目前还不具有的,即人类可以从一个或几个实例中学习一个新的概念,但深度学习标准算法则需要更多的实例才能勉强达到相同的能力。深度学习在图像识别、图像分割、自然语言处理等领域获得了较大进展,在实际应用中,庞大的训练样本量往往是高昂的,进而其训练样本量庞大的缺点饱受诟病。
研究者通过各种算法的设计,试图以少量训练样本达到同样的效果,除常见的数据增强、迁移学习外,小样本算法的研究也受到了广泛的关注。目前已有的研究方案,根据出发点和动机的不同,可以大致划分为下列三类:(1)基于数据扩充的小样本学习;(2)基于元学习的小样本学习;(3)基于度量的小样本学习,但目前为止,以原型网络等为代表的常规的小样本学习算法的效果仍有限。
大多数常规的小样本学习算法对于图片的局部和类别实例均视为同等重要,当样本量足够大的时候,特征向量离群点的影响不是很大,但在小样本的情况下,这种做法显然会使模型训练的波动十分剧烈,且离群点将对预测结果造成很大的干扰。
文献Snell,J.,Swersky,K.&Zemel,R.Prototypical networks for few-shotlearning.in Advances in Neural Information Processing Systems.4077-4087(2017).公开的原型网络为小样本学习算法对一张图片提取特征得到特征图后,会对该特征图进行全局平均池化或者拼接等方式形成一个向量来表示该这张图像的特征,随后用所有同类别的特征取平均的方式来代表一个类别“原型”。这种算法的缺点是,在小样本的情况下,背景变化多端,且物体随角度、光线等呈现出的图像差异较大,仅通过平均池化或者拼接的方式显然容易将背景特征或者极端角度下物体的特征过度重视,以至于通过这样的向量来表示一个类别“原型”效果并不理想。
因此,亟需设计一种小样本学习算法避免对所有部位和类别实例一视同仁的做法,能够很好地解决对背景和不重要实例的过度重视,从而能够更好地提取类别特征,提升小样本学习的效果。
发明内容
本发明提供一种基于跨样本注意力聚合的小样本学习方法,该方法能够清晰区分物体和背景区域,并能够准确提取类别特征。
一种基于跨样本注意力聚合的小样本学习方法,包括:
(1)将原始图片分为训练样本集和验证样本集,将训练样本集分为第一支持集和第一查询集,第一支持集包括多类别图片集,每个类别图片集包括多个相同类别图片,第一查询集包括多个相同类别的查询图片;
(2)构建训练模型,训练模型包括支持子模型和查询子模型,其中,支持子模型包括第一卷积神经网络和第一注意力聚合模块,将第一支持集输入至第一卷积神经网络后压缩至矩阵形式得到多类别特征向量,将多类别特征向量分别输入至第一注意力聚合模块得到第一支持集的多个类别聚合向量;
查询子模型包括第二卷积神经网络和第二注意力聚合模块,将查询集输入至第二卷积神经网络后压缩至矩阵形式得到单个类别特征向量,将单个类别特征向量输入至第二注意力聚合模块得到第一查询集的查询聚合向量;
基于多个类别聚合向量、查询聚合向量和真实类别聚合向量通过距离感知概率激活方法得到查询图片的类别预测概率值;
(3)构建损失函数,基于单一类别查询图片类别预测概率值构建交叉熵损失函数,通过交叉熵损失函数优化训练模型中的参数得到类别预测模型;
(4)应用时,将多类别图片集和查询图片集输入至类别预测模型得到每个查询图片的预测类别。
将第一支持集输入至第一卷积神经网络后压缩至矩阵形式得到多个类别特征向量,包括:
将第一支持集中的多个类别图片集输入至第一卷积神经网络得到多类别特征图集,将每个类别特征图集压缩至矩阵形式得到每个类别特征向量。
将每个类别特征向量输入至第一注意力聚合模块得到每个支持类别聚合向量,包括:
对每个类别特征向量求均值得到平均类别特征向量,将平均类别特征向量和每个类别特征向量输入至度量函数,对度量结果进行归一化得到权重向量,将转置的权重向量与每个类别特征向量进行矩阵乘操作得到每个支持类别聚合向量。
对度量结果进行归一化得到权重向量W为:
W=g(M)=softmax(-k*M)
其中,g(·)为归一化函数,M为度量结果,k为超参数,softmax(·)为激活函数。
基于多个类别聚合向量、查询聚合向量和真实类别聚合向量通过距离感知概率激活方法得到查询图片的类别预测概率值
Figure BDA0003573203390000031
为:
Figure BDA0003573203390000032
其中,xq为第一查询集中的第q个查询图片,yq为第一查询集中的第q个查询图片对应的真实类别标签,b(xq)为第一查询集中的第q个查询图片对应的查询聚合向量,a(yq)为第一查询集中的第q个查询图片的真实类别标签对应的第一支持集中的该类别聚合向量,a(k)为第一支持集中第k个类别图片对应的类别特征向量,K为类别个数,d(﹒)为度量函数。
第一支持集中第k个类别图片对应的类别特征向量a(k)为:
Figure BDA0003573203390000033
其中,aggregation(·)为聚合函数,reshape(·)为矩阵函数,
Figure BDA0003573203390000034
为第一卷积神经网络,
Figure BDA0003573203390000035
为第一卷积神经网络中可训练参数,xi为第一支持集中第i个图片,yi为第一支持集中第i个图片对应的类别标签。
第一查询集中的第q个查询图片对应的查询聚合向量b(xq)为:
Figure BDA0003573203390000036
其中,
Figure BDA0003573203390000037
为第二卷积神经网络,
Figure BDA0003573203390000038
为第二卷积神经网络中可训练参数。
通过交叉熵损失函数得到损失值Loss为:
Figure BDA0003573203390000039
通过反向传播算法优化第一、第二卷积神经网络中的可训练参数,使得损失值达到损失阈值,以完成对训练模型的训练得到类别预测模型,训练模型中的参数为第一、第二卷积神经网络中的可训练参数。
还包括对类别预测模型进行验证,具体步骤为:
(1)将验证样本集分为第二支持集和第二查询集,将第二支持集和第二查询集分别输入至类别预测模型得到基于第二查询图片的多个类别预测概率值,将最大类别预测概率值对应的类别作为第二查询图片的预测类别;
(2)重复步骤(1)达到预定的验证次数后,将每次验证得到的第二查询图片的预测类别和对应的真实类别进行比对,比对结果达到准确率阈值,则完成验证,未达到准确率阈值则继续优化训练模型中的参数。
一种基于跨样本注意力聚合的小样本学习装置,包括计算机存储器、计算机处理器以及存储在所述计算机存储器中并可在所述计算机处理器上执行的计算机程序,所述计算机存储器中采用权利要求1~9任一项所述的类别预测模型;
所述计算机处理器执行所述计算机程序时实现以下步骤:
将多类别图片集和查询图片集输入至类别预测模型得到每个查询图片的预测类别。
与现有技术相比,本发明的有益效果为:
(1)本发明通过注意力聚合模块的对输入的特征向量进行平均化操作,度量操作,以及将度量结果概率化的权重向量与特征向量进行矩阵乘计算得到聚合向量,从而避免了离群点的影响,准确的区分了图片的物体区域和背景区域。
(2)本发明将原始图片集分为支持集和查询集,将支持集分为不同的类别通过注意力聚合模块得到不同类别聚合向量,将不同类别聚合向量与查询聚合向量通过损失函数进行训练以得到能够准确预测查询图片的类别。
附图说明
图1为具体实施方式提供的基于跨样本注意力聚合的小样本学习方法框图;
图2为具体实施方式提供的注意力聚合模块示意图;
图3为具体实施方式提供的采用注意力聚合模块聚合示意图;
图4为具体实施方式提供的采用注意力聚合模块的图片效果图片。
具体实施方式
为了使模型提取的类别向量能够兼顾图像不同区域以及不同类别实例的特征,我们设计了跨样本注意力聚合的小样本学习方法,相比传统的小样本学习算法,此算法能够将不同样本和不同区域的特征综合考虑,聚合出更合适的类别向量。
本发明提供了一种基于跨样本注意力聚合的小样本学习方法,如图1所示,具体步骤为:
(1)支持集以2-way 3-shot(即共2个类别,每类3个图片)为例,查询集以一个图片为例。
(2)构建训练模型,训练模型包括支持子模型和查询子模型,其中,支持子模型包括第一卷积神经网络和第一注意力聚合模块,查询子模型包括第二卷积神经网络和第二注意力聚合模块。
将支持集输入至第一卷积神经网络得到2个特征组,每个特征组对应一个类别,每个特征组包括3个特征图,每个特征图的维度为5*5*64(即特征图的长和宽为5);将查询集输入至第二卷积神经网络得到一个维度为5*5*64的特征图,其中,第一卷积神经网络和第二卷积神经网络参数共享,即在训练后参数值相同,因此第一卷积神经网络和第二卷积神经网络均表示为CNN。
将支持集的特征图和查询集的特征图分别进行reshape操作,即压缩至矩阵形式,支持集需要按照类别分开操作,因此支持集的特征图经过reshape操作得到2组75*64的矩阵,即多类别特征向量,其中,每一行均为一个维度64的特征向量,共75个不同的向量,用于表示不同类别图片的所有不同区域的特征向量,同样地,查询集的图片则形成了25个不同区域的维度为64的单个类别特征向量,随后,将支持集的多类别特征向量和查询集的单个类别特征向量分别通过第一、二注意力聚合模块形成两个2*64类别聚合向量和一个1*64查询聚合向量,共三个聚合向量。
其中,类别特征向量a(k)为:
Figure BDA0003573203390000051
其中,aggregation(﹒)为聚合函数,reshape(﹒)为矩阵函数,
Figure BDA0003573203390000052
为第一卷积神经网络,
Figure BDA0003573203390000061
为第一卷积神经网络中可训练参数,xi为第一支持集中第i个图片,yi为第一支持集中第i个图片对应的类别标签。
查询聚合向量b(xq)为:
Figure BDA0003573203390000062
其中,
Figure BDA0003573203390000063
为第二卷积神经网络,
Figure BDA0003573203390000064
为第二卷积神经网络中可训练参数。
基于多个类别聚合向量、查询聚合向量和真实类别聚合向量通过距离感知概率激活方法得到查询图片的类别预测概率值
Figure BDA0003573203390000065
为:
Figure BDA0003573203390000066
其中,xq为第一查询集中的第q个查询图片,yq为第一查询集中的第q个查询图片对应的真实类别标签,b(xq)为第一查询集中的第q个查询图片对应的查询聚合向量,a(yq)为第一查询集中的第q个查询图片的真实类别标签对应的第一支持集中的该类别的聚合向量,a(k)为第一支持集中第k个类别图片对应的类别特征向量,K为类别个数,d(·)为度量函数,本算法中选择欧氏距离度量函数。
(3)通过交叉熵损失函数优化训练模型中的参数得到类别预测模型;接下来可通过以下损失函数Loss为:
Figure BDA0003573203390000067
通过反向传播算法优化第一、第二卷积神经网络中的可训练参数,使得损失值达到损失阈值,以完成对训练模型的训练得到类别预测模型,训练模型中的参数
Figure BDA0003573203390000068
为第一、第二卷积神经网络中的可训练参数。
其中,将支持集的多类别特征向量和查询集的单个类别特征向量分别通过第一、二注意力聚合模块形成两个类别聚合向量和一个查询聚合向量,共三个聚合向量,第一、二注意力聚合模块的聚合步骤相同,如图2所示,具体操作步骤如下:
将x个64维度的特征向量,如果x为75时则为类别特征向量,如果x为25时则为查询特征向量,为方便起见,将x个64维度的特征向量定义为特征向量R,将特征向量R输入至注意力聚合模块,将特征向量R求平均得到一个64维的平均向量,然后,采用度量函数(本发明使用的度量函数为欧式距离度量函数)将特征向量R和平均向量进行计算得到度量结果M,将度量结果M归一化得到权重向量W,将权重向量W转置后与特征向量R进行矩阵乘积操作,得到64维的聚合向量。
对度量结果进行归一化得到权重向量W为:
W=g(M)=softmax(-k*M)
其中,g(·)为归一化函数,M为度量结果,k为超参数,softmax(·)为激活函数。
如图3所示,图中蓝色圆点代表不同样本的特征向量,绿色三角形表示所有这些蓝色圆点的平均向量,这是在原型网络中采用的类原型的获取方式,可见由于离群点的引入导致其偏离了左下角数量占比较多的样本点的中心;而红色五角星则是根据注意力机制所聚合产生的向量,显然它离左下角数量占比较多的样本点的中心更接近,表明其受离群点的影响较小。
采用上述算法在文献Vinyals,O.,Blundell,C.,Lillicrap,T.&Wierstra,D.Matching networks for one shot learning.in Advances in neural informationprocessing systems.3630-3638(2016).中提供的小样本数据集miniImagenet上进行训练,并对其中的注意力效果进行可视化,结果呈现于图4中。从图4中发现,注意力聚合模块对于图像中的物体具有良好的聚焦效果(红色部分),而背景部分则显现出了更低的注意程度(蓝色部分)。因此本文所述跨样本注意力聚合算法能够达到设想中的提升对物体的注重程度的效果。
将原始图片集分为训练样本集和验证样本集:
对训练模型进行训练,具体步骤为:
从训练样本集中抽取K个类别的图像,每个类别N个样本,构成第一支持集S1,共K*N个样本;再从训练集中的K个类别的剩余样本中每个类别采样T个样本,组成第一查询集Q1,共K*T个样本;通过第一支持集S1和第一查询集Q1对训练模型进行反向传播训练;将所有样本放回训练样本集,并重复执行上述步骤使得损失值达到损失阈值或达到规定循环执行次数,以完成对训练模型的训练得到类别预测模型;
对类别预测模型进行验证,具体步骤为:
从验证样本集中抽取K个类别的图像,每个类别N个样本,构成第二支持集S2,共K*N个样本;再从验证样本集中的K个类别的剩余样本中每个类别采样T个样本,组成第二查询集Q2,共K*T个样本;将第二支持集S2和第二查询集Q2对分别输入至类别预测模型得到基于第二查询图片的多个类别预测概率值,将最大类别预测概率值对应的类别,或直接选择查询样本聚合向量至各类别聚合向量的距离最近的类别作为该查询样本的预测类别。多次验证后将每次验证得到的第二查询图片的预测类别和对应的真实类别进行比对,比对结果达到准确率阈值,则完成验证,未达到准确率阈值则继续优化训练模型中的参数。

Claims (10)

1.一种基于跨样本注意力聚合的小样本学习方法,其特征在于,包括:
(1)将原始图片集分为训练样本集和验证样本集,将训练样本集分为第一支持集和第一查询集,第一支持集包括多类别图片集,每个类别图片集包括多个相同类别图片,第一查询集包括多个相同类别的查询图片;
(2)构建训练模型,训练模型包括支持子模型和查询子模型,其中,支持子模型包括第一卷积神经网络和第一注意力聚合模块,将第一支持集输入至第一卷积神经网络后压缩至矩阵形式得到多类别特征向量,将多类别特征向量分别输入至第一注意力聚合模块得到第一支持集的多个类别聚合向量;
查询子模型包括第二卷积神经网络和第二注意力聚合模块,将查询集输入至第二卷积神经网络后压缩至矩阵形式得到单个类别特征向量,将单个类别特征向量输入至第二注意力聚合模块得到第一查询集的查询聚合向量;
基于多个类别聚合向量、查询聚合向量和真实类别聚合向量通过距离感知概率激活方法得到查询图片的类别预测概率值;
(3)构建损失函数,基于单一类别查询图片类别预测概率值构建交叉熵损失函数,通过交叉熵损失函数优化训练模型中的参数得到类别预测模型;
(4)应用时,将多类别图片集和查询图片集输入至类别预测模型得到每个查询图片的预测类别。
2.根据权利要求1所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,将第一支持集输入至第一卷积神经网络后压缩至矩阵形式得到多个类别特征向量,包括:
将第一支持集中的多个类别图片集输入至第一卷积神经网络得到多类别特征图集,将每个类别特征图集压缩至矩阵形式得到每个类别特征向量。
3.根据权利要求1所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,将每个类别特征向量输入至第一注意力聚合模块得到每个支持类别聚合向量,包括:
对每个类别特征向量求均值得到平均类别特征向量,将平均类别特征向量和每个类别特征向量输入至度量函数,对度量结果进行归一化得到权重向量,将转置的权重向量与每个类别特征向量进行矩阵乘操作得到每个支持类别聚合向量。
4.根据权利要求3所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,对度量结果进行归一化得到权重向量W为:
W=g(M)=softmax(-k*M)
其中,g(﹒)为归一化函数,M为度量结果,k为超参数,softmax(﹒)为激活函数。
5.根据权利要求1所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,基于多个类别聚合向量、查询聚合向量和真实类别聚合向量通过距离感知概率激活方法得到查询图片的类别预测概率值
Figure FDA0003573203380000021
为:
Figure FDA0003573203380000022
其中,xq为第一查询集中的第q个查询图片,yq为第一查询集中的第q个查询图片对应的真实类别标签,b(xq)为第一查询集中的第q个查询图片对应的查询聚合向量,a(yq)为第一查询集中的第q个查询图片的真实类别标签所对应的第一支持集中的该类别聚合向量,a(k)为第一支持集中第k个类别图片对应的类别特征向量,K为类别个数,d(﹒)为度量函数,
Figure FDA0003573203380000023
为训练模型中的参数。
6.根据权利要求5所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,类别特征向量a(k)为:
Figure FDA0003573203380000024
其中,aggregation(﹒)为聚合函数,reshape(﹒)为矩阵函数,
Figure FDA0003573203380000025
为第一卷积神经网络,
Figure FDA0003573203380000026
为第一卷积神经网络中可训练参数,xi为第一支持集中第i个图片,yi为第一支持集中第i个图片对应的类别标签。
7.根据权利要求6所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,查询聚合向量b(xq)为:
Figure FDA0003573203380000027
其中,
Figure FDA0003573203380000028
为第二卷积神经网络,
Figure FDA0003573203380000029
为第二卷积神经网络中可训练参数。
8.根据权利要求1所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,通过交叉熵损失函数得到损失值Loss为:
Figure FDA0003573203380000031
通过反向传播算法优化第一、第二卷积神经网络中的可训练参数,使得损失值达到损失阈值,以完成对训练模型的训练得到类别预测模型,训练模型中的参数
Figure FDA0003573203380000032
为第一、第二卷积神经网络中的可训练参数。
9.根据权利要求1所述的基于跨样本注意力聚合的小样本学习方法,其特征在于,还包括对类别预测模型进行验证,具体步骤为:
(1)将验证样本集分为第二支持集和第二查询集,将第二支持集和第二查询集分别输入至类别预测模型得到基于第二查询图片的多个类别预测概率值,将最大类别预测概率值对应的类别作为第二查询图片的预测类别;
(2)重复步骤(1)达到预定的验证次数后,将每次验证得到的第二查询图片的预测类别和对应的真实类别进行比对,比对结果达到准确率阈值,则完成验证,未达到准确率阈值则继续优化训练模型中的参数。
10.一种基于跨样本注意力聚合的小样本学习装置,包括计算机存储器、计算机处理器以及存储在所述计算机存储器中并可在所述计算机处理器上执行的计算机程序,其特征在于,所述计算机存储器中采用权利要求1~9任一项所述的类别预测模型;
所述计算机处理器执行所述计算机程序时实现以下步骤:
将多类别图片集和查询图片集输入至类别预测模型得到每个查询图片的预测类别。
CN202210331296.XA 2022-03-30 2022-03-30 一种基于跨样本注意力聚合的小样本学习方法和装置 Pending CN114639000A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210331296.XA CN114639000A (zh) 2022-03-30 2022-03-30 一种基于跨样本注意力聚合的小样本学习方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210331296.XA CN114639000A (zh) 2022-03-30 2022-03-30 一种基于跨样本注意力聚合的小样本学习方法和装置

Publications (1)

Publication Number Publication Date
CN114639000A true CN114639000A (zh) 2022-06-17

Family

ID=81951129

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210331296.XA Pending CN114639000A (zh) 2022-03-30 2022-03-30 一种基于跨样本注意力聚合的小样本学习方法和装置

Country Status (1)

Country Link
CN (1) CN114639000A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115019175A (zh) * 2022-06-27 2022-09-06 华南农业大学 一种基于迁移元学习的害虫识别方法
CN117058470A (zh) * 2023-10-12 2023-11-14 宁德思客琦智能装备有限公司 一种基于小样本学习的三维点云分类的方法和系统

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115019175A (zh) * 2022-06-27 2022-09-06 华南农业大学 一种基于迁移元学习的害虫识别方法
CN117058470A (zh) * 2023-10-12 2023-11-14 宁德思客琦智能装备有限公司 一种基于小样本学习的三维点云分类的方法和系统
CN117058470B (zh) * 2023-10-12 2024-01-26 宁德思客琦智能装备有限公司 一种基于小样本学习的三维点云分类的方法和系统

Similar Documents

Publication Publication Date Title
Liu et al. Rankiqa: Learning from rankings for no-reference image quality assessment
CN110163258B (zh) 一种基于语义属性注意力重分配机制的零样本学习方法及系统
CN107977932B (zh) 一种基于可鉴别属性约束生成对抗网络的人脸图像超分辨率重建方法
CN109063724B (zh) 一种增强型生成式对抗网络以及目标样本识别方法
WO2020228525A1 (zh) 地点识别及其模型训练的方法和装置以及电子设备
US9400919B2 (en) Learning deep face representation
CN114202672A (zh) 一种基于注意力机制的小目标检测方法
CN109543602B (zh) 一种基于多视角图像特征分解的行人再识别方法
CN109978041B (zh) 一种基于交替更新卷积神经网络的高光谱图像分类方法
CN111738363B (zh) 基于改进的3d cnn网络的阿尔茨海默病分类方法
CN114639000A (zh) 一种基于跨样本注意力聚合的小样本学习方法和装置
Tscherepanow TopoART: A topology learning hierarchical ART network
CN107169117B (zh) 一种基于自动编码器和dtw的手绘图人体运动检索方法
CN110309835B (zh) 一种图像局部特征提取方法及装置
CN111652273B (zh) 一种基于深度学习的rgb-d图像分类方法
CN110879982A (zh) 一种人群计数系统及方法
CN110188827A (zh) 一种基于卷积神经网络和递归自动编码器模型的场景识别方法
CN113066065A (zh) 无参考图像质量检测方法、系统、终端及介质
CN115311502A (zh) 基于多尺度双流架构的遥感图像小样本场景分类方法
CN116091823A (zh) 一种基于快速分组残差模块的单特征无锚框目标检测方法
CN114780767A (zh) 一种基于深度卷积神经网络的大规模图像检索方法及系统
CN114492755A (zh) 基于知识蒸馏的目标检测模型压缩方法
CN112329662B (zh) 基于无监督学习的多视角显著性估计方法
CN116113952A (zh) 用于图像的属于分布内度量的分布之间的距离
CN117011655A (zh) 基于自适应区域选择特征融合方法、目标跟踪方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination