CN114067155A - 基于元学习的图像分类方法、装置、产品及存储介质 - Google Patents

基于元学习的图像分类方法、装置、产品及存储介质 Download PDF

Info

Publication number
CN114067155A
CN114067155A CN202111352723.4A CN202111352723A CN114067155A CN 114067155 A CN114067155 A CN 114067155A CN 202111352723 A CN202111352723 A CN 202111352723A CN 114067155 A CN114067155 A CN 114067155A
Authority
CN
China
Prior art keywords
task
training
clustering
model
feature vector
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111352723.4A
Other languages
English (en)
Inventor
刘璇
曾兴隆
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hunan University
Original Assignee
Hunan University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hunan University filed Critical Hunan University
Priority to CN202111352723.4A priority Critical patent/CN114067155A/zh
Publication of CN114067155A publication Critical patent/CN114067155A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • G06F18/232Non-hierarchical techniques
    • G06F18/2321Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
    • G06F18/23213Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Molecular Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种基于元学习的图像分类方法、装置、产品及存储介质,对于当前的任务
Figure DDA0003356392860000011
根据输入的训练样本
Figure DDA0003356392860000012
训练一个样本特征提取器
Figure DDA0003356392860000013
以及任务特征提取器F,获取对应的任务特征vi。将vi输入到聚类网络层,通过最小距离匹配以及在线聚类的方式更新对应的聚类特征向量集合
Figure DDA0003356392860000014
并输出最适配当前任务的聚类特征向量cω。基于任务特征评估网络σ对任务特征向量cω及当前任务特征vi进行相关性评估,输出对应的评估系数α,并基于该系数计算适用于当前任务的表征向量ωout。基于ωout对模型的参数进行调制,得到最优的模型初始化参数θnew。模型基于该初始化参数θnew开始训练更新。本发明能够在不进行人为干扰的情况下完成任务先验知识的更新以及相关性适配并高效复用以加速新的图像分类任务的学习。

Description

基于元学习的图像分类方法、装置、产品及存储介质
技术领域
本发明涉及元学习和图像识别领域,特别是一种基于元学习的图像分类方法、装置、产品及存储介质。
背景技术
近年来,深度学习领域的发展让机器学习步入一个新的阶段,深度神经网络的发展也使得机器学习模型能够获得优异的性能。然而,现有的机器学习模型大多数需要大量被标记的训练样本以及大量的训练时间。然而,对应于现实生活中,样本的标签收集是困难的,甚至有时候样本自身的丰富程度也存在差异,有些任务可能只有少量的训练样本。同时,对于时效性较强的模型训练,也不允许使用大量的时间进行训练。针对上述问题,一种解决方法就是通过知识迁移的方式,借助其他任务的学习经验,来增加新任务的学习效率,同时减少训练样本需求量。元学习旨在通过复用在过去学习过的任务中的学习经验,进而加速新任务的学习进程。
现有的元学习算法主要分为基于度量,基于确定模型以及基于梯度更新的元学习算法,本发明主要关注基于梯度更新的元学习算法。现有的基于梯度更新的元学习算法通过为当前任务生成最优模型初始化参数的方式来加速新任务的学习,由于上述算法所产生的模型初始化参数只考虑当前模型过去所学习过的所有任务,产生使得对于过去所有的任务都能快速适应的模型初始化参数。上述方式主要存在两大缺陷:①模型并没有考虑新任务是否与过去学习的任务相关,而初始化参数只与过去学习的任务相关。②模型过去学习的任务可能自身存在分布的差异,生成对应于所有任务都最有的模型初始化参数可能是使得模型缺乏特异性,导致在新任务上表现不佳。因此,考虑新任务与过去学习任务的相关性以及差异性,将是提高元学习算法学习效率的一个高效手段。
问题定义如图1所示,假设有一系列任务
Figure BDA0003356392840000011
从任务概率分布
Figure BDA0003356392840000012
中随机抽样而来,每个任务
Figure BDA0003356392840000013
有多个样本
Figure BDA0003356392840000014
每个任务的样本通常分为两个部分称为训练集
Figure BDA0003356392840000015
和测试集
Figure BDA0003356392840000016
在元学习场景中,模型训练通常分为元训练阶段和元测试阶段。给定一个元学习器f和模型参数θ。元训练的目标是在训练任务上训练模型f来存储先验知识,并通过使用先验知识使模型学习得更快,而元测试则是为了验证模型的拟合效果以及泛化能力。在小样本图像识别场景下,一个任务通常是指对多个不同类别的图像进行识别,元学习的目标就是在经过一系列图像识别任务的元训练阶段后之后,使得元学习模型能够有足够强的泛化能力,使其能够高效学习新任务。
现有的基于梯度的元学习算法没有考虑新任务与过去学习过的任务群里的分布差异以及相关性导致算法的性能下降。尽管已经有几项研究工作考虑新任务自身的分布对模型初始化参数的影响。但是根据人类的学习经验,先验知识与当前任务的相关性以及先验知识自身的内部差异是应该被考虑在内的。
在现有的相关文献中,公开号为CN111539448A的发明专利公开了一种基于元学习的少样本图像分类方法,该方法考虑对所有的训练集数据进行k-means聚类,并基于聚类的结果对待适应的新任务进行适配。但是,该方法考虑需要一次性拿到所有的训练任务数据。而在实际的模型训练中,元学习的任务一般为序列到达,模型一次性只能接触一个任务,并在不断的学习过程中更新、储存并复用有效的先验知识。公开号为CN111724083A的发明专利公开了一种金融风险识别模型的训练方法、装置、计算机设备及介质。该方法通过对训练数据进行聚类并基于新任务与不同类别的训练数据的距离矩阵来确定并训练多个对应的分类器,旨在为不同类别的任务类别提供多个特定的分类器,并用于金融风险的快速预测与评估。多个分类器的方法使得新的评测数据能够使用更加有效的分类器,但这同时也限制了模型的可扩展性,如果存在新的预测任务与预处理时使用的任务存在较大差异,那么基于预训练得到的分类器可能就不适用于新任务。
发明内容
本发明所要解决的技术问题是,针对现有技术不足,提供一种基于元学习的图像分类方法、装置、产品及存储介质,能够适用于任何基于梯度更新的元学习模型,并基于上述元学习模型快速适应新的图像分类任务并得到优异的图像分类效果。
为解决上述技术问题,本发明所采用的技术方案是:一种基于元学习的图像分类方法,包括以下步骤:
S1、对于当前的包含Ni个样本的任务
Figure BDA0003356392840000021
Figure BDA0003356392840000022
分别表示输入的当前任务的每一个样本及其标签,将当前任务样本分为训练样本
Figure BDA0003356392840000023
以及测试样本
Figure BDA0003356392840000024
其中,
Figure BDA0003356392840000025
根据输入的训练样本
Figure BDA0003356392840000026
训练样本特征提取器
Figure BDA0003356392840000027
以及任务特征提取器F,并基于样本特征提取器
Figure BDA0003356392840000028
以及任务特征提取器F获取当前任务的任务特征vi;其中,所述训练样本为图像;
S2、利用任务特征vi,更新对应的聚类特征向量集合
Figure BDA0003356392840000031
并输出最适配当前任务的聚类特征向量cω;其中,K表示聚类网络层中聚类的蔟数,cm表示聚类网络层中的聚类特征向量;
S3、基于任务特征评估网络σ对聚类特征向量cω及当前任务特征vi进行相关性评估,输出对应的评估系数α,并基于该评估系数α输出适用于当前任务的表征向量ωout;所述任务特征评估网络σ包括多个级联的连接层;
S4、基于任务的表征向量ωout对卷积神经网络模型的参数进行更新,得到最优的模型初始化参数θnew
S5、基于所述初始化参数θnew训练并更新卷积神经网络模型参数,得到图像分类模型。
本发明具有如下优势:
(1)步骤S1的任务特征提取器基于自动编码器解码器将当前任务的多个样本特征整合为一个任务特征向量,为模型提供一个可靠且维度有限的任务表征,避免了网络输入参数过大带来的训练复杂的影响。
(2)步骤S2基于与当前任务最相似的先验知识表征进行了更新,同时输出更新后的聚类特征向量。保证了模型中输出的先验知识与当前任务具备更高的相关性。同时,聚类特征的在线扩充技术的使用也使得模型在遇到其他分布的任务时,能够自动扩展聚类个数,既保证了模型训练的稳定性,也保证了参数更新的可靠性。
(3)步骤S3评估了当前任务的特征向量与任务聚类层输出的聚类特征向量的相关性,并根据结果生成先验表征向量,从而进一步保证了先验知识的可靠性,同时也保证了任务聚类层所保存的任务聚类特征能够被准确更新。
(4)步骤S4通过使用先验表征向量对模型初始化参数进行更新,保证了更新后的模型初始化参数能够更加适用于当前任务,保证其能够在更少的训练次数下获得更好的分类效果。
(5)步骤S5基于调制后的模型初始化参数完成模型的训练,因为算法没有明确要求模型的具体实现以及网络构造,因此使得本发明能够适用于任何基于梯度更新的元学习模型。
步骤S1中,当前任务特征vi的表达式为:
Figure BDA0003356392840000041
其中,gi,j为对样本特征向量
Figure BDA0003356392840000042
进行编码得到的编码向量:
Figure BDA0003356392840000043
其中,RNNenc表示以当前任务样本特征向量
Figure BDA0003356392840000044
作为输入的循环神经网络;训练期间,将训练样本
Figure BDA0003356392840000045
输入到样本特征提取器
Figure BDA0003356392840000046
Figure BDA0003356392840000047
的输出即为对应的样本特征向量集合
Figure BDA0003356392840000048
Figure BDA0003356392840000049
为当前任务的第j个训练样本。
通过在S1阶段对任务的样本进行特征提取,可以将当前任务的输入样本,即图片从二维矩阵提取成一维的向量,可以在尽可能保证数据信息完整的情况下解决因为输入维度过大导致的模型训练复杂的问题,同时,任务特征vi被表示为多个样本特征的均值,可以最大可能的降低因为样本输入顺序带来的误差,也使得所提取的任务特征更加具备可靠性。
步骤S2的具体实现过程包括:
1)对比当前任务特征向量vi与所有存储的聚类特征向量集合
Figure BDA00033563928400000410
的欧式距离d(cj,vi),并选出距离最小的聚类特征向量cω
2)若步骤1)中聚类特征向量cω与当前任务的任务特征向量vi之间的距离d(cω,vi)大于设定的距离阈值dmax,同时,当前次训练对应的训练损失Lnow与上一次训练的训练损失Llast变化大于设定的变化阈值γ,即:d(cω,vi)>dmax and Lnow/Llast>γ,则聚类中心个数加1,并添加一个新的聚类特征向量cK+1,利用公式cK+1=vi初始化新的聚类特征向量,并将ck+1加入到聚类特征向量集合
Figure BDA00033563928400000411
中,返回步骤1);若上述条件不满足,则利用公式
Figure BDA00033563928400000412
N'ω=Nω+1获取更新的聚类特征向量cω',Nω代表当前被选择的聚类特征向量的更新次数,直接输出c'ω
通过步骤S2,可以针对性的选择和当前任务最相关的任务聚类特征并单独更新对应的聚类特征,保证了每个任务聚类特征都能够存储对应于不同任务分布的先验知识,因此也保证了在后续的步骤中能够得到更加适用于当前任务的先验表征向量;另一方面,步骤2中的2)提供了一个错误检测与恢复机制让模型能够自主的判断当前任务的分布与过去学习过的任务分布之间的差异性,使得模型在遇到属于未知任务分布的新任务时能够自动化的增加新的聚类特征向量,同时也保证现有保存的聚类信息不被干扰,保证了模型训练的稳定性以及先验知识的可靠性。
步骤S3中,适用于当前任务的表征向量ωout的表达式为:
ωout=α·cω+(1-α)·vi
其中,
Figure BDA0003356392840000051
σ表示由全连接网络构成的评估网络,
Figure BDA0003356392840000052
表示将cω以及vi合并为一个向量并作为评估网络的输入参数。
步骤S3基于自动生成的相关性系数α中对当前任务特征vi以及被选择的任务聚类特征cω进行线性调制。使得当前任务可以自主控制先验表征ωout中先验知识与当前任务信息之间的权重,进一步保证了当前任务能够基于与自身最相关的知识产生可靠的模型初始化参数,进而有效减少模型的训练时间。
步骤S4中,
Figure BDA0003356392840000053
其中,
Figure BDA0003356392840000054
M表示模型参数的数量;gi为门控函数,由一组全连接网络层构成;φg表示门控函数的网络参数。
步骤S4使用了一系列门控函数,用以将S3步骤获得的表征向量ωout调制为不同的形状,用以匹配模型中不同形状的网络参数;同时,门控函数的参数也是基于梯度下降更新,这也使得同一个任务表征向量对不同的模型参数有各异的调制效果,保证了每一个模型参数都能有适用于当前任务的初始化内容。
步骤S5中,模型训练时基于梯度下降对模型的参数进行更新,其中,更新过程中使用的损失函数为
Figure BDA0003356392840000055
上述损失函数包含三部分,分别是当前任务的损失函数
Figure BDA0003356392840000056
Figure BDA0003356392840000057
在图像分类任务中设定为交叉熵损失函数
Figure BDA0003356392840000058
Figure BDA0003356392840000059
为步骤S1中的任务特征提取网络的损失函数,即
Figure BDA0003356392840000061
Figure BDA0003356392840000062
为正则化损失函数,即
Figure BDA0003356392840000063
其中θ代表模型的参数。
步骤S5中,模型基于梯度下降对调制后的初始化参数进行更新,由于模型参数并没有明确给定,本方法适用于任意基于梯度更新的元学习算法。这也使得本发明不仅能使用于图像分类任务,针对不同的学习目标,仅需要修改对应的任务损失函数,就可以完成模型网络的训练准备工作。
本发明还提供了一种计算机装置,包括存储器、处理器及存储在存储器上的计算机程序;所述处理器执行所述计算机程序,以实现本发明方法的步骤;
本发明还提供了一种计算机可读存储介质,其上存储有计算机程序/指令;其特征在于,所述计算机程序/指令被处理器执行时实现本发明方法的步骤。
本发明还提供了一种计算机程序产品,包括计算机程序/指令;其特征在于,该计算机程序/指令被处理器执行时实现本发明方法的步骤;
与现有技术相比,本发明所具有的有益效果为:
(1)本发明能够不依赖额外的信息对先验知识进行准确区分、更新并高效复用。因为来源于不同任务分布的图像分类任务之间共享较少的先验知识,通过在线聚类对来自不同任务分布的任务先验知识进行聚类区分,可以只筛选相关性高的先验知识来帮助新任务快速学习,进而提高新任务的学习效率。
(2)本发明所计算的模型初始化参数综合考虑了新任务与先验知识之间的相关性并给出与当前任务最相关的先验知识,因此本发明能够为基于梯度的元学习算法提供更加可靠的模型初始化参数,使得模型能够在新任务中快速学习,进而在有限的更新次数下,获得更加优异的图像分类效果。
(3)本发明的模型对于未出现的任务分布所产生的任务,即新任务的任务分布与训练任务的分布存在较大差距的情况下,能够自主识别并扩充保存的先验知识聚簇,进而高效复用相关性强的先验知识,减少错误先验知识带来的负面影响。
(4)本发明的模型能够在训练任务序列到达的情况下完成先验知识的存储、更新及复用,进而高效完成新任务的快速适应。因为本发明的模型基于在线聚类以及动态扩展的方式处理任务的先验知识,符合现实任务的相关设定;
附图说明
图1为本发明的元训练任务以及元测试任务示意图;
图2为基于元学习的图像分类方法的执行流程示意图。
具体实施方式
本发明的小样本图像分类场景下的通用的基于在线聚类的元学习方法包括以下步骤:
如图2所示,本发明主要由五个阶段组成。在特征提取阶段,通过将输入的图片进行特征提取,将多个类别中的若干张图片编码为任务特征向量,该过程能够通过梯度更新的方式对特征提取器进行优化,使得模型能够逐渐提高所提取任务特征的可靠性。在任务特征聚类阶段,通过将提取的任务特征输入到任务聚类层并找出与当前任务最相近的任务聚类特征并更新该聚类特征,该步骤能够输出与当前任务最具备相关性的过往任务分布的任务聚类特征。在相似性评估阶段,本发明使用一个评估网络生成任务聚类特征与当前任务特征的相关性系数,并基于这个系数计算模型的先验表征向量,即任务特征向量以及聚类特征向量的加权和,基于上述操作,本发明同时考究了当前的任务以及过去学习过的任务的重要性,能让模型的初始化参数具备更多特异性。在参数调制阶段,本发明通过一系列门控函数,将先验表征向量调制为与模型参数相同的维度,并将调制后的先验表征向量和对应的模型参数进行相乘,最后得到新的模型初始化参数,该参数能够更加适用于当前的任务。模型训练阶段将调制后的模型参数传入具体的模型并基于梯度下降开始训练,使得模型能够在有限次数的更新后,得到优异的分类结果。
1)特征提取
特征提取主要分为为两个步骤;样本特征提取以及任务特征提取。一个任务通常含有多个不同类别的样本,在样本特征提取阶段,我们采用样本特征提取器
Figure BDA0003356392840000071
完成样本特征提取,
Figure BDA0003356392840000072
是基于聚合器嵌入的方式实现,具体实现为一个由两个卷积层和两个全连接层组成的块。而在任务特征提取阶段,我们采用的是自动编解码器F,F中的自动编码器和自动解码器都是基于RNN实现,通过上述两层特征提取器,就可以完成从任务样本输入到任务特征的输出。
不失一般性,考虑当前为第i个任务
Figure BDA0003356392840000073
该任务有多个训练样本
Figure BDA0003356392840000074
为了保证训练结果的可靠性,我们只使用训练样本
Figure BDA0003356392840000075
完成特征提取。
首先,将每一个训练样本
Figure BDA0003356392840000076
按序输入到样本特征提取器
Figure BDA0003356392840000077
并输出对应的样本特征集合
Figure BDA0003356392840000078
Figure BDA0003356392840000079
一般是一个1×n的向量,在小样本图像分类中,n的值设定为128。然后,将获取的样本特征集合
Figure BDA0003356392840000081
输入到任务特征提取器F中,对应的输出对应的任务特征向量gi,j,为了避免应为样本的输入顺序对特征提取网络训练带来的负面影响,考虑将样本特征进行随机打乱,并输入到任务特征网络,得到多个任务特征向量,具体的计算方式为:
Figure BDA0003356392840000082
最后,将多个任务样本特征gi,j的均值vi作为当前按任务的特征向量输出:
Figure BDA0003356392840000083
我们基于梯度下降的方式来训练上述任务特征提取器。为了保证所提取的任务特征能够可靠表征当前任务,我们将获取的任务特征通过解码器进行还原为Fdec(gi,j),并根据解码的样本特征与原始任务样本特征的差值来训练特征提取器网络,具体的损失函数为:
Figure BDA0003356392840000084
2)任务特征聚类
Ⅰ、将从任务特征提取器所生成的当前任务的任务特征向量vi输入到任务聚类层,任务聚类层通过对比当前任务特征向量vi与所有存储的聚类特征向量集合
Figure BDA0003356392840000085
基于欧式距离进行距离d(cj,vi)比对,并选出对应的距离最小的聚类特征向量cω,即
Figure BDA0003356392840000086
Ⅱ、为了保证所选择的聚类特征向量的可靠性,本方法考虑将当前所求得的最小距离与最近两批次的训练损失变化作为评估条件,若步骤Ⅰ中对应距离最小的聚类特征向量cω与当前任务的任务特征向量vi之间的距离d(cω,vi)大于设定的距离阈值dmax,同时,当前批次对应的训练损失Lnow与上一批次的训练损失Llast变化大于设定的变化阈值γ,即:
d(cj,vi)>dmax and Lnow/Llast>γ;
执行聚类层的自动扩充操作,具体的执行方式如下:
①设置聚类中心个数K=K+1,并添加一个新的簇向量cK+1并初始化新的簇
cK+1=vi,NK+1=1;
②将ck+1添加到聚类特征向量集合
Figure BDA0003356392840000091
重置模型并将存储参数从缓存重新加载卷积神经网络模型并返回步骤Ⅰ。
当上述条件不满足时,则对所选择的任务特征向量ω进行更新操作,具体方式如下:
Figure BDA0003356392840000092
其中Nω代表当前被选择的聚类特征向量cω的更新次数。最后,将更新后的任务特征向量c'ω作为任务聚类层的输出。
在小样本图像分类场景下,聚类层的初始聚类蔟数被设定为4,而距离阈值dmax以及损失变化阈值γ被设定为0.80以及1.25。
3)相关性评估
相关性评估主要分为两个阶段,首先时根据任务特征评估网络σ获取对应的相关性系数α,然后根据输出的系数对两个特征进行线性组合,具体如下:
Ⅰ、本发明使用一个任务特征评估网络σ对任务特征向量cω及当前任务特征vi进行相关性评估。评估网络σ同时接受当前任务特征vi以及任务聚类层输出cω作为网络的输入参数,并输出对应的相关性系数α,具体实现为:
Figure BDA0003356392840000093
其中σ基于一个全连接网络实现,
Figure BDA0003356392840000094
表示将任务特征vi以及上层任务聚类层输出cω合并为一个张量并输入任务特征评估网络σ,输出的α是一个值位于0–1区间内的向量。
Ⅱ、基于特征评估网络σ所输出的相关性系数α,通过计算得到当前任务的先验表征向量ωout,具体计算方式为:
ωout=α·cω+(1-α)·vi
通过上述步骤,我们可以获得一个同时关注过去先验知识以及当前新任务的特征的任务表征向量ωout,基于该向量对模型参数调制,就能产生对当前任务更加适合的模型初始化参数。
4)参数调制
参数调制阶段分为两个步骤,首先,由于模型本身的结构性质,模型内可能会存在多种不同形状,甚至不同维度的网络参数,为了将任务表征向量ωout作用于每一个可训练的模型参数,本发明基于一系列门控函数g将ωout调制为对应维度的调制向量v。然后,基于获得的调制向量将模型参数调制为具备任务特异性的模型初始化参数。下面是具体的实现步骤:
Ⅰ、为了保证步骤S3得到的先验表征向量ωout和所需要调制的模型参数θ(
Figure BDA0003356392840000101
其中M表示参数的数量)的维度适配,本发明使用一系列门控函数gi,用来将ωout调制成与模型参数θ维度一致的调制向量
Figure BDA0003356392840000102
具体的实现方式为:
Figure BDA0003356392840000103
其中,M表示模型的网络参数的数量,φg表示每个门控函数的网络参数。
Ⅱ、使用调制向量
Figure BDA0003356392840000104
对模型的参数θ进行调制,生成最适合当前任务的模型初始化参数θnew,具体实现如下:
Figure BDA0003356392840000105
通过上述过程产生的模型参数θnew为适用于当前任务的模型初始化参数。
5)模型训练
本发明主要基于梯度下降对模型的参数训练,在小样本图像分类场景下,我们将具有32个通道的传统4层卷积神经网络作为基础元学习模型,并基于损失函数
Figure BDA0003356392840000106
进行参数更新,
Figure BDA0003356392840000107
的组成如下:
Figure BDA0003356392840000108
其中
Figure BDA0003356392840000109
表示图像分类模型的原始目标函数,在本实验中设置为交叉熵损失函数,即:
Figure BDA00033563928400001010
Figure BDA00033563928400001011
表示嵌入损失函数,即上述的任务特征提取网络的损失函数:
Figure BDA0003356392840000111
最后一项是L2正则化。即
Figure BDA0003356392840000112
我们添加它是为了保持模型参数在训练过程中的稳定性。而μ1、μ2分别表示不同的权重系数。在大部分的实验设置中,我们将参数μ1以及μ2设置固定值。
实验仿真与分析
基准方法,我们将我们的方法DCML分别与一些典型的基于优化和基于度量的算法进行比较。首先是共享参数方法MAML和Meta-SGD,然后是基于任务特定特征的方法MT-Net、MUMOMAML、BMAML、HSML。对于基于度量的元学习方法,我们选择基于集群的算法原型网络作为基线。所有实验设置都基于传统的小样本分类设置。
数据集介绍,根据基线的实验设置,我们使用四个数据集进行实验验证,它们是Bird、Texture、Aircraft和Fungi。下面对上述数据集进行简要介绍:
Bird Dataset:它是Caltech-UCSDBirds-200-2011(CUB-200-2011)的子数据集。CUB-200-2011是一个鸟类图像数据集,包含200种鸟类的11,788张照片。在我们的实验设置中,我们从该数据集中随机选择100个物种,每个物种包含60个样本。
Texture Dataset:它选自可描述纹理数据集(DTD),这是一个包含来自47个类别的5640张图像的纹理图像数据集。我们从每个类别分别选择了120张照片。
Aircraft Dataset:原始数据集是飞机细粒度视觉分类数据集(FGVC-Aircraft),这是一个用于飞机细粒度视觉分类的图像数据集。该数据集包含102种不同的飞机图像。在我们的实验设置中,我们随机选择了100类飞机图像,每类选择100张图片。
Fungi Dataset:我们通过从FGVCx-Fungi(Fungi)中随机选择100个物种,每个物种150张图像来组成该数据集,该数据集拥有近1,500种野生蘑菇物种的超过100,000张真菌图像。此外,为了方便筛选,我们提前将该数据集下样本数量少于150的类别剔除。
在实验中,每个数据集被分为元训练数据集、元验证数据集和元测试数据集。我们从每个子数据集中用几个样本随机选择一些类,每个样本大小调整为84×84×3,所选择的类分别用于元训练、元验证和元测试。对于鸟类数据集,64/16/20比例的种类分别被分为元训练数据集、元验证数据集和元测试数据集;对于Texture数据集,比例为30/7/10;而对于Aircraft数据集和Fungi数据集,这个比例也设置为64/16/20。
实验结果分析
系统采用Bird、Texture、Aircraft和Fungi作为基础数据集,分别完成了5-way 1-shot以及5-way 5-shot小样本图像分类实验,其中N-way k-shot表示对于每个数据集,我们通过随机抽样N个类和每个类的K个样本来构建任务。表1以及表2分别表示DCML在5-way1-shot以及5-way 5-shot场景下的分类准确率。此外,为了验证系统在连续适应场景下的性能,我们进行了5-way 1-shot连续适应实验。具体的,在实验过程中,每个阶段用来获取任务的数据集的丰富度是存在差异的,模型在15000轮迭代训练之前,采用的训练数据集为Bird、Texture,在15000轮次,我们添加了Aircraft,而在25000轮次,Fungi也被添加入备选数据集,而实验中每一个任务的选择,是基于当前训练数据集来进行选择。表4展示了DCML在连续适应实验中的分类准确率。根据上述结果,可以观察到:(1)DCML的性能明显由于其他的元学习方法MAML、Meta-SGD、MT-Net、MUMOMAML、BMAML、HSML、Prototype Network。表明了DCML的小样本场景下的有效性。(2)DCML具备明显优于其他方法的平均性能,表明了DCML对于不同种环境下的普适能力以及鲁棒性。(3)在连续适应实验场景中,DCML在大部分情况下都优于现有的元学习方法,表明DCML对于任务分别变化的特征捕捉以及适应能力。值得注意的是,DCML在实际执行时,如果需要添加一类新的任务特征聚类,模型只需要增加一个特征向量以及特征计数,就可以完成模型的扩展,相比于其他的方法,模型能够保证更少的内存消耗。
表1 5-way 1-shot图像分类实验结果
Figure BDA0003356392840000131
表2 5-way 5-shot图像分类实验结果
Figure BDA0003356392840000132
Figure BDA0003356392840000141
表3 5-way 1-shot连续适应实验结果
Figure BDA0003356392840000142
消融分析
为了验证系统中不同组件的有效性。我们对不同类型的组件和超参数进行了消融研究。首先,我们探索了评估网络δ的有效性,该网络用于生成重要性系数α。在这部分,实验分为两个部分。第一部分我们验证了评估网络σ的有效性,第二部分我们研究了评估网络中不同激活函数的影响。我们在消融研究中进行了5-way1-shot图像分类实验。
表4显示了我们的实验结果。结果表明,在没有评估网络的情况下,模型的分类准确率普遍低于有评估网络的其他两种情况。同时,DCML在5-way 1-shot图像分类场景下,使用softmax作为激活函数能够获得更高的分类准确度。
表4评估网络σ消融研究实验结果
Figure BDA0003356392840000151

Claims (8)

1.一种基于元学习的图像分类方法,其特征在于,包括以下步骤:
S1、对于当前的包含Ni个样本的任务
Figure FDA0003356392830000011
Figure FDA0003356392830000012
分别表示输入的当前任务的每一个样本及其标签,将当前任务样本分为训练样本
Figure FDA0003356392830000013
以及测试样本
Figure FDA0003356392830000014
其中,
Figure FDA0003356392830000015
根据输入的训练样本
Figure FDA0003356392830000016
训练样本特征提取器
Figure FDA0003356392830000017
以及任务特征提取器F,并基于样本特征提取器
Figure FDA0003356392830000018
以及任务特征提取器F获取当前任务的任务特征vi;其中,所述训练样本为图像;
S2、利用任务特征vi,更新对应的聚类特征向量集合
Figure FDA0003356392830000019
并输出最适配当前任务的聚类特征向量cω;其中,K表示聚类层聚类的蔟数,cm表示聚类网络层中的聚类特征向量;
S3、基于任务特征评估网络σ对聚类特征向量cω及当前任务特征vi进行相关性评估,输出对应的评估系数α,并基于该评估系数α输出适用于当前任务的表征向量ωout;所述任务特征评估网络σ包括多个级联的连接层;
S4、基于任务的表征向量ωout对卷积神经网络模型的参数进行更新,得到最优的模型初始化参数θnew
S5、基于所述初始化参数θnew训练并更新卷积神经网络模型参数,得到图像分类模型。
2.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S1中,当前任务特征vi的表达式为:
Figure FDA00033563928300000110
其中,gij为对样本特征向量
Figure FDA00033563928300000111
进行编码得到的编码向量:
Figure FDA00033563928300000112
其中,RNNenc表示以当前任务样本特征向量
Figure FDA00033563928300000113
作为输入的循环神经网络;训练期间,将训练样本
Figure FDA0003356392830000021
输入到样本特征提取器
Figure FDA0003356392830000022
的输出即为对应的样本特征向量集合
Figure FDA0003356392830000023
Figure FDA0003356392830000024
为当前任务的第j个训练样本。
3.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S2的具体实现过程包括:
1)对比当前任务特征向量vi与所有存储的聚类特征向量集合
Figure FDA0003356392830000025
的欧式距离d(cj,vi),并选出距离最小的聚类特征向量cω
2)若步骤1)中聚类特征向量cω与当前任务的任务特征向量vi之间的距离d(cω,vi)大于设定的距离阈值dmax,同时,当前次训练对应的训练损失Lnow与上一次训练的训练损失Llast变化大于设定的变化阈值γ,即:d(cω,vi)>dmax and Lnow/Llast>γ,则聚类中心个数加1,并添加一个新的聚类特征向量cK+1,利用公式cK+1=vi初始化新的聚类特征向量,并将ck+1加入到聚类特征向量集合
Figure FDA0003356392830000026
中,返回步骤1);若上述条件不满足,则利用公式
Figure FDA0003356392830000027
N'ω=Nω+1获取更新的聚类特征向量cω',Nω代表当前被选择的聚类特征向量的更新次数,直接输出c'ω
4.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S3中,适用于当前任务的表征向量ωout的表达式为:
ωout=α·cω+(1-α)·vi
其中,
Figure FDA0003356392830000028
σ表示由全连接网络构成的评估网络,
Figure FDA0003356392830000029
表示将cω以及vi合并为一个向量并作为评估网络的输入参数。
5.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S4中,
Figure FDA00033563928300000210
其中,
Figure FDA00033563928300000211
M表示模型参数的数量;gi为门控函数,由一组全连接网络层构成;φg表示门控函数的网络参数。
6.一种计算机装置,包括存储器、处理器及存储在存储器上的计算机程序;其特征在于,所述处理器执行所述计算机程序,以实现权利要求1~6之一所述方法的步骤。
7.一种计算机可读存储介质,其上存储有计算机程序/指令;其特征在于,所述计算机程序/指令被处理器执行时实现权利要求1~6之一所述方法的步骤。
8.一种计算机程序产品,包括计算机程序/指令;其特征在于,该计算机程序/指令被处理器执行时实现权利要求1~6之一所述方法的步骤。
CN202111352723.4A 2021-11-16 2021-11-16 基于元学习的图像分类方法、装置、产品及存储介质 Pending CN114067155A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111352723.4A CN114067155A (zh) 2021-11-16 2021-11-16 基于元学习的图像分类方法、装置、产品及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111352723.4A CN114067155A (zh) 2021-11-16 2021-11-16 基于元学习的图像分类方法、装置、产品及存储介质

Publications (1)

Publication Number Publication Date
CN114067155A true CN114067155A (zh) 2022-02-18

Family

ID=80272632

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111352723.4A Pending CN114067155A (zh) 2021-11-16 2021-11-16 基于元学习的图像分类方法、装置、产品及存储介质

Country Status (1)

Country Link
CN (1) CN114067155A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116563638A (zh) * 2023-05-19 2023-08-08 广东石油化工学院 一种基于情景记忆的图像分类模型优化方法和系统

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111539448A (zh) * 2020-03-17 2020-08-14 广东省智能制造研究所 一种基于元学习的少样本图像分类方法
WO2020249125A1 (zh) * 2019-06-14 2020-12-17 第四范式(北京)技术有限公司 用于自动训练机器学习模型的方法和系统
WO2021164625A1 (en) * 2020-02-17 2021-08-26 Huawei Technologies Co., Ltd. Method of training an image classification model

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020249125A1 (zh) * 2019-06-14 2020-12-17 第四范式(北京)技术有限公司 用于自动训练机器学习模型的方法和系统
WO2021164625A1 (en) * 2020-02-17 2021-08-26 Huawei Technologies Co., Ltd. Method of training an image classification model
CN111539448A (zh) * 2020-03-17 2020-08-14 广东省智能制造研究所 一种基于元学习的少样本图像分类方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
孙勋: "基于机器学习的极化SAR图像分类", 《中国优秀硕士学位论文全文数据库 信息科技辑》, 15 June 2020 (2020-06-15), pages 136 - 589 *
蒋留兵;周小龙;姜风伟;车俐;: "基于改进匹配网络的单样本学习", 系统工程与电子技术, no. 06, 22 March 2019 (2019-03-22) *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116563638A (zh) * 2023-05-19 2023-08-08 广东石油化工学院 一种基于情景记忆的图像分类模型优化方法和系统
CN116563638B (zh) * 2023-05-19 2023-12-05 广东石油化工学院 一种基于情景记忆的图像分类模型优化方法和系统

Similar Documents

Publication Publication Date Title
CN112308158B (zh) 一种基于部分特征对齐的多源领域自适应模型及方法
Liu et al. Learning spatio-temporal representations for action recognition: A genetic programming approach
CN109993100B (zh) 基于深层特征聚类的人脸表情识别的实现方法
US20120093396A1 (en) Digital image analysis utilizing multiple human labels
CN113076994B (zh) 一种开集域自适应图像分类方法及系统
CN113222011B (zh) 一种基于原型校正的小样本遥感图像分类方法
Bochinski et al. Deep active learning for in situ plankton classification
WO2022062419A1 (zh) 基于非督导金字塔相似性学习的目标重识别方法及系统
CN113128478B (zh) 模型训练方法、行人分析方法、装置、设备及存储介质
CN111400494B (zh) 一种基于GCN-Attention的情感分析方法
CN113761259A (zh) 一种图像处理方法、装置以及计算机设备
CN111079837B (zh) 一种用于二维灰度图像检测识别分类的方法
CN109902662A (zh) 一种行人重识别方法、系统、装置和存储介质
CN112232395B (zh) 一种基于联合训练生成对抗网络的半监督图像分类方法
CN113569895A (zh) 图像处理模型训练方法、处理方法、装置、设备及介质
Li et al. Dating ancient paintings of Mogao Grottoes using deeply learnt visual codes
CN109472733A (zh) 基于卷积神经网络的图像隐写分析方法
CN114511739A (zh) 一种基于元迁移学习的任务自适应的小样本图像分类方法
CN110569780A (zh) 一种基于深度迁移学习的高精度人脸识别方法
CN114821237A (zh) 一种基于多级对比学习的无监督船舶再识别方法及系统
CN111126155B (zh) 一种基于语义约束生成对抗网络的行人再识别方法
CN116310466A (zh) 基于局部无关区域筛选图神经网络的小样本图像分类方法
CN114067155A (zh) 基于元学习的图像分类方法、装置、产品及存储介质
CN114329031A (zh) 一种基于图神经网络和深度哈希的细粒度鸟类图像检索方法
CN117079017A (zh) 可信的小样本图像识别分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination