CN114067155B - 基于元学习的图像分类方法、装置、产品及存储介质 - Google Patents
基于元学习的图像分类方法、装置、产品及存储介质 Download PDFInfo
- Publication number
- CN114067155B CN114067155B CN202111352723.4A CN202111352723A CN114067155B CN 114067155 B CN114067155 B CN 114067155B CN 202111352723 A CN202111352723 A CN 202111352723A CN 114067155 B CN114067155 B CN 114067155B
- Authority
- CN
- China
- Prior art keywords
- task
- feature
- training
- model
- feature vector
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 43
- 238000003860 storage Methods 0.000 title claims abstract description 9
- 239000013598 vector Substances 0.000 claims abstract description 92
- 238000012549 training Methods 0.000 claims abstract description 79
- 238000011156 evaluation Methods 0.000 claims abstract description 32
- 238000012512 characterization method Methods 0.000 claims abstract description 20
- 230000006870 function Effects 0.000 claims description 24
- 238000004590 computer program Methods 0.000 claims description 14
- 238000012360 testing method Methods 0.000 claims description 9
- 230000008859 change Effects 0.000 claims description 7
- 238000013527 convolutional neural network Methods 0.000 claims description 6
- 238000013528 artificial neural network Methods 0.000 claims description 3
- 238000013145 classification model Methods 0.000 claims description 3
- 230000000306 recurrent effect Effects 0.000 claims description 2
- 230000006978 adaptation Effects 0.000 abstract description 8
- 238000009826 distribution Methods 0.000 description 14
- 238000002474 experimental method Methods 0.000 description 13
- 238000000605 extraction Methods 0.000 description 12
- 238000004422 calculation algorithm Methods 0.000 description 11
- 208000014165 immunodeficiency 21 Diseases 0.000 description 11
- 230000000694 effects Effects 0.000 description 8
- 230000008569 process Effects 0.000 description 8
- 241000233866 Fungi Species 0.000 description 6
- 238000013459 approach Methods 0.000 description 6
- 241000894007 species Species 0.000 description 5
- 238000002679 ablation Methods 0.000 description 4
- 238000004458 analytical method Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 238000012795 verification Methods 0.000 description 3
- 101100455978 Arabidopsis thaliana MAM1 gene Proteins 0.000 description 2
- 230000004913 activation Effects 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000000007 visual effect Effects 0.000 description 2
- 235000001674 Agaricus brunnescens Nutrition 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000002538 fungal effect Effects 0.000 description 1
- 238000003064 k means clustering Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 238000002360 preparation method Methods 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000010200 validation analysis Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
- G06F18/23213—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Molecular Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于元学习的图像分类方法、装置、产品及存储介质,对于当前的任务根据输入的训练样本训练一个样本特征提取器以及任务特征提取器F,获取对应的任务特征vi。将vi输入到聚类网络层,通过最小距离匹配以及在线聚类的方式更新对应的聚类特征向量集合并输出最适配当前任务的聚类特征向量cω。基于任务特征评估网络σ对任务特征向量cω及当前任务特征vi进行相关性评估,输出对应的评估系数α,并基于该系数计算适用于当前任务的表征向量ωout。基于ωout对模型的参数进行调制,得到最优的模型初始化参数θnew。模型基于该初始化参数θnew开始训练更新。本发明能够在不进行人为干扰的情况下完成任务先验知识的更新以及相关性适配并高效复用以加速新的图像分类任务的学习。
Description
技术领域
本发明涉及元学习和图像识别领域,特别是一种基于元学习的图像分类方法、装置、产品及存储介质。
背景技术
近年来,深度学习领域的发展让机器学习步入一个新的阶段,深度神经网络的发展也使得机器学习模型能够获得优异的性能。然而,现有的机器学习模型大多数需要大量被标记的训练样本以及大量的训练时间。然而,对应于现实生活中,样本的标签收集是困难的,甚至有时候样本自身的丰富程度也存在差异,有些任务可能只有少量的训练样本。同时,对于时效性较强的模型训练,也不允许使用大量的时间进行训练。针对上述问题,一种解决方法就是通过知识迁移的方式,借助其他任务的学习经验,来增加新任务的学习效率,同时减少训练样本需求量。元学习旨在通过复用在过去学习过的任务中的学习经验,进而加速新任务的学习进程。
现有的元学习算法主要分为基于度量,基于确定模型以及基于梯度更新的元学习算法,本发明主要关注基于梯度更新的元学习算法。现有的基于梯度更新的元学习算法通过为当前任务生成最优模型初始化参数的方式来加速新任务的学习,由于上述算法所产生的模型初始化参数只考虑当前模型过去所学习过的所有任务,产生使得对于过去所有的任务都能快速适应的模型初始化参数。上述方式主要存在两大缺陷:①模型并没有考虑新任务是否与过去学习的任务相关,而初始化参数只与过去学习的任务相关。②模型过去学习的任务可能自身存在分布的差异,生成对应于所有任务都最有的模型初始化参数可能是使得模型缺乏特异性,导致在新任务上表现不佳。因此,考虑新任务与过去学习任务的相关性以及差异性,将是提高元学习算法学习效率的一个高效手段。
问题定义如图1所示,假设有一系列任务从任务概率分布中随机抽样而来,每个任务有多个样本每个任务的样本通常分为两个部分称为训练集和测试集在元学习场景中,模型训练通常分为元训练阶段和元测试阶段。给定一个元学习器f和模型参数θ。元训练的目标是在训练任务上训练模型f来存储先验知识,并通过使用先验知识使模型学习得更快,而元测试则是为了验证模型的拟合效果以及泛化能力。在小样本图像识别场景下,一个任务通常是指对多个不同类别的图像进行识别,元学习的目标就是在经过一系列图像识别任务的元训练阶段后之后,使得元学习模型能够有足够强的泛化能力,使其能够高效学习新任务。
现有的基于梯度的元学习算法没有考虑新任务与过去学习过的任务群里的分布差异以及相关性导致算法的性能下降。尽管已经有几项研究工作考虑新任务自身的分布对模型初始化参数的影响。但是根据人类的学习经验,先验知识与当前任务的相关性以及先验知识自身的内部差异是应该被考虑在内的。
在现有的相关文献中,公开号为CN111539448A的发明专利公开了一种基于元学习的少样本图像分类方法,该方法考虑对所有的训练集数据进行k-means聚类,并基于聚类的结果对待适应的新任务进行适配。但是,该方法考虑需要一次性拿到所有的训练任务数据。而在实际的模型训练中,元学习的任务一般为序列到达,模型一次性只能接触一个任务,并在不断的学习过程中更新、储存并复用有效的先验知识。公开号为CN111724083A的发明专利公开了一种金融风险识别模型的训练方法、装置、计算机设备及介质。该方法通过对训练数据进行聚类并基于新任务与不同类别的训练数据的距离矩阵来确定并训练多个对应的分类器,旨在为不同类别的任务类别提供多个特定的分类器,并用于金融风险的快速预测与评估。多个分类器的方法使得新的评测数据能够使用更加有效的分类器,但这同时也限制了模型的可扩展性,如果存在新的预测任务与预处理时使用的任务存在较大差异,那么基于预训练得到的分类器可能就不适用于新任务。
发明内容
本发明所要解决的技术问题是,针对现有技术不足,提供一种基于元学习的图像分类方法、装置、产品及存储介质,能够适用于任何基于梯度更新的元学习模型,并基于上述元学习模型快速适应新的图像分类任务并得到优异的图像分类效果。
为解决上述技术问题,本发明所采用的技术方案是:一种基于元学习的图像分类方法,包括以下步骤:
S1、对于当前的包含Ni个样本的任务 分别表示输入的当前任务的每一个样本及其标签,将当前任务样本分为训练样本以及测试样本其中,根据输入的训练样本训练样本特征提取器以及任务特征提取器F,并基于样本特征提取器以及任务特征提取器F获取当前任务的任务特征vi;其中,所述训练样本为图像;
S2、利用任务特征vi,更新对应的聚类特征向量集合并输出最适配当前任务的聚类特征向量cω;其中,K表示聚类网络层中聚类的蔟数,cm表示聚类网络层中的聚类特征向量;
S3、基于任务特征评估网络σ对聚类特征向量cω及当前任务特征vi进行相关性评估,输出对应的评估系数α,并基于该评估系数α输出适用于当前任务的表征向量ωout;所述任务特征评估网络σ包括多个级联的连接层;
S4、基于任务的表征向量ωout对卷积神经网络模型的参数进行更新,得到最优的模型初始化参数θnew;
S5、基于所述初始化参数θnew训练并更新卷积神经网络模型参数,得到图像分类模型。
本发明具有如下优势:
(1)步骤S1的任务特征提取器基于自动编码器解码器将当前任务的多个样本特征整合为一个任务特征向量,为模型提供一个可靠且维度有限的任务表征,避免了网络输入参数过大带来的训练复杂的影响。
(2)步骤S2基于与当前任务最相似的先验知识表征进行了更新,同时输出更新后的聚类特征向量。保证了模型中输出的先验知识与当前任务具备更高的相关性。同时,聚类特征的在线扩充技术的使用也使得模型在遇到其他分布的任务时,能够自动扩展聚类个数,既保证了模型训练的稳定性,也保证了参数更新的可靠性。
(3)步骤S3评估了当前任务的特征向量与任务聚类层输出的聚类特征向量的相关性,并根据结果生成先验表征向量,从而进一步保证了先验知识的可靠性,同时也保证了任务聚类层所保存的任务聚类特征能够被准确更新。
(4)步骤S4通过使用先验表征向量对模型初始化参数进行更新,保证了更新后的模型初始化参数能够更加适用于当前任务,保证其能够在更少的训练次数下获得更好的分类效果。
(5)步骤S5基于调制后的模型初始化参数完成模型的训练,因为算法没有明确要求模型的具体实现以及网络构造,因此使得本发明能够适用于任何基于梯度更新的元学习模型。
步骤S1中,当前任务特征vi的表达式为:
其中,gi,j为对样本特征向量进行编码得到的编码向量:
其中,RNNenc表示以当前任务样本特征向量作为输入的循环神经网络;训练期间,将训练样本输入到样本特征提取器 的输出即为对应的样本特征向量集合 为当前任务的第j个训练样本。
通过在S1阶段对任务的样本进行特征提取,可以将当前任务的输入样本,即图片从二维矩阵提取成一维的向量,可以在尽可能保证数据信息完整的情况下解决因为输入维度过大导致的模型训练复杂的问题,同时,任务特征vi被表示为多个样本特征的均值,可以最大可能的降低因为样本输入顺序带来的误差,也使得所提取的任务特征更加具备可靠性。
步骤S2的具体实现过程包括:
1)对比当前任务特征向量vi与所有存储的聚类特征向量集合的欧式距离d(cj,vi),并选出距离最小的聚类特征向量cω;
2)若步骤1)中聚类特征向量cω与当前任务的任务特征向量vi之间的距离d(cω,vi)大于设定的距离阈值dmax,同时,当前次训练对应的训练损失Lnow与上一次训练的训练损失Llast变化大于设定的变化阈值γ,即:d(cω,vi)>dmax and Lnow/Llast>γ,则聚类中心个数加1,并添加一个新的聚类特征向量cK+1,利用公式cK+1=vi初始化新的聚类特征向量,并将ck+1加入到聚类特征向量集合中,返回步骤1);若上述条件不满足,则利用公式N'ω=Nω+1获取更新的聚类特征向量cω',Nω代表当前被选择的聚类特征向量的更新次数,直接输出c'ω。
通过步骤S2,可以针对性的选择和当前任务最相关的任务聚类特征并单独更新对应的聚类特征,保证了每个任务聚类特征都能够存储对应于不同任务分布的先验知识,因此也保证了在后续的步骤中能够得到更加适用于当前任务的先验表征向量;另一方面,步骤2中的2)提供了一个错误检测与恢复机制让模型能够自主的判断当前任务的分布与过去学习过的任务分布之间的差异性,使得模型在遇到属于未知任务分布的新任务时能够自动化的增加新的聚类特征向量,同时也保证现有保存的聚类信息不被干扰,保证了模型训练的稳定性以及先验知识的可靠性。
步骤S3中,适用于当前任务的表征向量ωout的表达式为:
ωout=α·cω+(1-α)·vi;
其中,σ表示由全连接网络构成的评估网络,表示将cω以及vi合并为一个向量并作为评估网络的输入参数。
步骤S3基于自动生成的相关性系数α中对当前任务特征vi以及被选择的任务聚类特征cω进行线性调制。使得当前任务可以自主控制先验表征ωout中先验知识与当前任务信息之间的权重,进一步保证了当前任务能够基于与自身最相关的知识产生可靠的模型初始化参数,进而有效减少模型的训练时间。
步骤S4中,其中,M表示模型参数的数量;gi为门控函数,由一组全连接网络层构成;φg表示门控函数的网络参数。
步骤S4使用了一系列门控函数,用以将S3步骤获得的表征向量ωout调制为不同的形状,用以匹配模型中不同形状的网络参数;同时,门控函数的参数也是基于梯度下降更新,这也使得同一个任务表征向量对不同的模型参数有各异的调制效果,保证了每一个模型参数都能有适用于当前任务的初始化内容。
步骤S5中,模型训练时基于梯度下降对模型的参数进行更新,其中,更新过程中使用的损失函数为上述损失函数包含三部分,分别是当前任务的损失函数 在图像分类任务中设定为交叉熵损失函数 为步骤S1中的任务特征提取网络的损失函数,即而为正则化损失函数,即其中θ代表模型的参数。
步骤S5中,模型基于梯度下降对调制后的初始化参数进行更新,由于模型参数并没有明确给定,本方法适用于任意基于梯度更新的元学习算法。这也使得本发明不仅能使用于图像分类任务,针对不同的学习目标,仅需要修改对应的任务损失函数,就可以完成模型网络的训练准备工作。
本发明还提供了一种计算机装置,包括存储器、处理器及存储在存储器上的计算机程序;所述处理器执行所述计算机程序,以实现本发明方法的步骤;
本发明还提供了一种计算机可读存储介质,其上存储有计算机程序/指令;其特征在于,所述计算机程序/指令被处理器执行时实现本发明方法的步骤。
本发明还提供了一种计算机程序产品,包括计算机程序/指令;其特征在于,该计算机程序/指令被处理器执行时实现本发明方法的步骤;
与现有技术相比,本发明所具有的有益效果为:
(1)本发明能够不依赖额外的信息对先验知识进行准确区分、更新并高效复用。因为来源于不同任务分布的图像分类任务之间共享较少的先验知识,通过在线聚类对来自不同任务分布的任务先验知识进行聚类区分,可以只筛选相关性高的先验知识来帮助新任务快速学习,进而提高新任务的学习效率。
(2)本发明所计算的模型初始化参数综合考虑了新任务与先验知识之间的相关性并给出与当前任务最相关的先验知识,因此本发明能够为基于梯度的元学习算法提供更加可靠的模型初始化参数,使得模型能够在新任务中快速学习,进而在有限的更新次数下,获得更加优异的图像分类效果。
(3)本发明的模型对于未出现的任务分布所产生的任务,即新任务的任务分布与训练任务的分布存在较大差距的情况下,能够自主识别并扩充保存的先验知识聚簇,进而高效复用相关性强的先验知识,减少错误先验知识带来的负面影响。
(4)本发明的模型能够在训练任务序列到达的情况下完成先验知识的存储、更新及复用,进而高效完成新任务的快速适应。因为本发明的模型基于在线聚类以及动态扩展的方式处理任务的先验知识,符合现实任务的相关设定;
附图说明
图1为本发明的元训练任务以及元测试任务示意图;
图2为基于元学习的图像分类方法的执行流程示意图。
具体实施方式
本发明的小样本图像分类场景下的通用的基于在线聚类的元学习方法包括以下步骤:
如图2所示,本发明主要由五个阶段组成。在特征提取阶段,通过将输入的图片进行特征提取,将多个类别中的若干张图片编码为任务特征向量,该过程能够通过梯度更新的方式对特征提取器进行优化,使得模型能够逐渐提高所提取任务特征的可靠性。在任务特征聚类阶段,通过将提取的任务特征输入到任务聚类层并找出与当前任务最相近的任务聚类特征并更新该聚类特征,该步骤能够输出与当前任务最具备相关性的过往任务分布的任务聚类特征。在相似性评估阶段,本发明使用一个评估网络生成任务聚类特征与当前任务特征的相关性系数,并基于这个系数计算模型的先验表征向量,即任务特征向量以及聚类特征向量的加权和,基于上述操作,本发明同时考究了当前的任务以及过去学习过的任务的重要性,能让模型的初始化参数具备更多特异性。在参数调制阶段,本发明通过一系列门控函数,将先验表征向量调制为与模型参数相同的维度,并将调制后的先验表征向量和对应的模型参数进行相乘,最后得到新的模型初始化参数,该参数能够更加适用于当前的任务。模型训练阶段将调制后的模型参数传入具体的模型并基于梯度下降开始训练,使得模型能够在有限次数的更新后,得到优异的分类结果。
1)特征提取
特征提取主要分为为两个步骤;样本特征提取以及任务特征提取。一个任务通常含有多个不同类别的样本,在样本特征提取阶段,我们采用样本特征提取器完成样本特征提取,是基于聚合器嵌入的方式实现,具体实现为一个由两个卷积层和两个全连接层组成的块。而在任务特征提取阶段,我们采用的是自动编解码器F,F中的自动编码器和自动解码器都是基于RNN实现,通过上述两层特征提取器,就可以完成从任务样本输入到任务特征的输出。
不失一般性,考虑当前为第i个任务该任务有多个训练样本为了保证训练结果的可靠性,我们只使用训练样本完成特征提取。
首先,将每一个训练样本按序输入到样本特征提取器并输出对应的样本特征集合 一般是一个1×n的向量,在小样本图像分类中,n的值设定为128。然后,将获取的样本特征集合输入到任务特征提取器F中,对应的输出对应的任务特征向量gi,j,为了避免应为样本的输入顺序对特征提取网络训练带来的负面影响,考虑将样本特征进行随机打乱,并输入到任务特征网络,得到多个任务特征向量,具体的计算方式为:
最后,将多个任务样本特征gi,j的均值vi作为当前按任务的特征向量输出:
我们基于梯度下降的方式来训练上述任务特征提取器。为了保证所提取的任务特征能够可靠表征当前任务,我们将获取的任务特征通过解码器进行还原为Fdec(gi,j),并根据解码的样本特征与原始任务样本特征的差值来训练特征提取器网络,具体的损失函数为:
2)任务特征聚类
Ⅰ、将从任务特征提取器所生成的当前任务的任务特征向量vi输入到任务聚类层,任务聚类层通过对比当前任务特征向量vi与所有存储的聚类特征向量集合基于欧式距离进行距离d(cj,vi)比对,并选出对应的距离最小的聚类特征向量cω,即
Ⅱ、为了保证所选择的聚类特征向量的可靠性,本方法考虑将当前所求得的最小距离与最近两批次的训练损失变化作为评估条件,若步骤Ⅰ中对应距离最小的聚类特征向量cω与当前任务的任务特征向量vi之间的距离d(cω,vi)大于设定的距离阈值dmax,同时,当前批次对应的训练损失Lnow与上一批次的训练损失Llast变化大于设定的变化阈值γ,即:
d(cj,vi)>dmax and Lnow/Llast>γ;
执行聚类层的自动扩充操作,具体的执行方式如下:
①设置聚类中心个数K=K+1,并添加一个新的簇向量cK+1并初始化新的簇
cK+1=vi,NK+1=1;
②将ck+1添加到聚类特征向量集合重置模型并将存储参数从缓存重新加载卷积神经网络模型并返回步骤Ⅰ。
当上述条件不满足时,则对所选择的任务特征向量ω进行更新操作,具体方式如下:
其中Nω代表当前被选择的聚类特征向量cω的更新次数。最后,将更新后的任务特征向量c'ω作为任务聚类层的输出。
在小样本图像分类场景下,聚类层的初始聚类蔟数被设定为4,而距离阈值dmax以及损失变化阈值γ被设定为0.80以及1.25。
3)相关性评估
相关性评估主要分为两个阶段,首先时根据任务特征评估网络σ获取对应的相关性系数α,然后根据输出的系数对两个特征进行线性组合,具体如下:
Ⅰ、本发明使用一个任务特征评估网络σ对任务特征向量cω及当前任务特征vi进行相关性评估。评估网络σ同时接受当前任务特征vi以及任务聚类层输出cω作为网络的输入参数,并输出对应的相关性系数α,具体实现为:
其中σ基于一个全连接网络实现,表示将任务特征vi以及上层任务聚类层输出cω合并为一个张量并输入任务特征评估网络σ,输出的α是一个值位于0–1区间内的向量。
Ⅱ、基于特征评估网络σ所输出的相关性系数α,通过计算得到当前任务的先验表征向量ωout,具体计算方式为:
ωout=α·cω+(1-α)·vi;
通过上述步骤,我们可以获得一个同时关注过去先验知识以及当前新任务的特征的任务表征向量ωout,基于该向量对模型参数调制,就能产生对当前任务更加适合的模型初始化参数。
4)参数调制
参数调制阶段分为两个步骤,首先,由于模型本身的结构性质,模型内可能会存在多种不同形状,甚至不同维度的网络参数,为了将任务表征向量ωout作用于每一个可训练的模型参数,本发明基于一系列门控函数g将ωout调制为对应维度的调制向量v。然后,基于获得的调制向量将模型参数调制为具备任务特异性的模型初始化参数。下面是具体的实现步骤:
Ⅰ、为了保证步骤S3得到的先验表征向量ωout和所需要调制的模型参数θ(其中M表示参数的数量)的维度适配,本发明使用一系列门控函数gi,用来将ωout调制成与模型参数θ维度一致的调制向量具体的实现方式为:
其中,M表示模型的网络参数的数量,φg表示每个门控函数的网络参数。
Ⅱ、使用调制向量对模型的参数θ进行调制,生成最适合当前任务的模型初始化参数θnew,具体实现如下:
通过上述过程产生的模型参数θnew为适用于当前任务的模型初始化参数。
5)模型训练
本发明主要基于梯度下降对模型的参数训练,在小样本图像分类场景下,我们将具有32个通道的传统4层卷积神经网络作为基础元学习模型,并基于损失函数进行参数更新,的组成如下:
其中表示图像分类模型的原始目标函数,在本实验中设置为交叉熵损失函数,即:
表示嵌入损失函数,即上述的任务特征提取网络的损失函数:
最后一项是L2正则化。即我们添加它是为了保持模型参数在训练过程中的稳定性。而μ1、μ2分别表示不同的权重系数。在大部分的实验设置中,我们将参数μ1以及μ2设置固定值。
实验仿真与分析
基准方法,我们将我们的方法DCML分别与一些典型的基于优化和基于度量的算法进行比较。首先是共享参数方法MAML和Meta-SGD,然后是基于任务特定特征的方法MT-Net、MUMOMAML、BMAML、HSML。对于基于度量的元学习方法,我们选择基于集群的算法原型网络作为基线。所有实验设置都基于传统的小样本分类设置。
数据集介绍,根据基线的实验设置,我们使用四个数据集进行实验验证,它们是Bird、Texture、Aircraft和Fungi。下面对上述数据集进行简要介绍:
Bird Dataset:它是Caltech-UCSDBirds-200-2011(CUB-200-2011)的子数据集。CUB-200-2011是一个鸟类图像数据集,包含200种鸟类的11,788张照片。在我们的实验设置中,我们从该数据集中随机选择100个物种,每个物种包含60个样本。
Texture Dataset:它选自可描述纹理数据集(DTD),这是一个包含来自47个类别的5640张图像的纹理图像数据集。我们从每个类别分别选择了120张照片。
Aircraft Dataset:原始数据集是飞机细粒度视觉分类数据集(FGVC-Aircraft),这是一个用于飞机细粒度视觉分类的图像数据集。该数据集包含102种不同的飞机图像。在我们的实验设置中,我们随机选择了100类飞机图像,每类选择100张图片。
Fungi Dataset:我们通过从FGVCx-Fungi(Fungi)中随机选择100个物种,每个物种150张图像来组成该数据集,该数据集拥有近1,500种野生蘑菇物种的超过100,000张真菌图像。此外,为了方便筛选,我们提前将该数据集下样本数量少于150的类别剔除。
在实验中,每个数据集被分为元训练数据集、元验证数据集和元测试数据集。我们从每个子数据集中用几个样本随机选择一些类,每个样本大小调整为84×84×3,所选择的类分别用于元训练、元验证和元测试。对于鸟类数据集,64/16/20比例的种类分别被分为元训练数据集、元验证数据集和元测试数据集;对于Texture数据集,比例为30/7/10;而对于Aircraft数据集和Fungi数据集,这个比例也设置为64/16/20。
实验结果分析
系统采用Bird、Texture、Aircraft和Fungi作为基础数据集,分别完成了5-way 1-shot以及5-way 5-shot小样本图像分类实验,其中N-way k-shot表示对于每个数据集,我们通过随机抽样N个类和每个类的K个样本来构建任务。表1以及表2分别表示DCML在5-way1-shot以及5-way 5-shot场景下的分类准确率。此外,为了验证系统在连续适应场景下的性能,我们进行了5-way 1-shot连续适应实验。具体的,在实验过程中,每个阶段用来获取任务的数据集的丰富度是存在差异的,模型在15000轮迭代训练之前,采用的训练数据集为Bird、Texture,在15000轮次,我们添加了Aircraft,而在25000轮次,Fungi也被添加入备选数据集,而实验中每一个任务的选择,是基于当前训练数据集来进行选择。表4展示了DCML在连续适应实验中的分类准确率。根据上述结果,可以观察到:(1)DCML的性能明显由于其他的元学习方法MAML、Meta-SGD、MT-Net、MUMOMAML、BMAML、HSML、Prototype Network。表明了DCML的小样本场景下的有效性。(2)DCML具备明显优于其他方法的平均性能,表明了DCML对于不同种环境下的普适能力以及鲁棒性。(3)在连续适应实验场景中,DCML在大部分情况下都优于现有的元学习方法,表明DCML对于任务分别变化的特征捕捉以及适应能力。值得注意的是,DCML在实际执行时,如果需要添加一类新的任务特征聚类,模型只需要增加一个特征向量以及特征计数,就可以完成模型的扩展,相比于其他的方法,模型能够保证更少的内存消耗。
表1 5-way 1-shot图像分类实验结果
表2 5-way 5-shot图像分类实验结果
表3 5-way 1-shot连续适应实验结果
消融分析
为了验证系统中不同组件的有效性。我们对不同类型的组件和超参数进行了消融研究。首先,我们探索了评估网络δ的有效性,该网络用于生成重要性系数α。在这部分,实验分为两个部分。第一部分我们验证了评估网络σ的有效性,第二部分我们研究了评估网络中不同激活函数的影响。我们在消融研究中进行了5-way1-shot图像分类实验。
表4显示了我们的实验结果。结果表明,在没有评估网络的情况下,模型的分类准确率普遍低于有评估网络的其他两种情况。同时,DCML在5-way 1-shot图像分类场景下,使用softmax作为激活函数能够获得更高的分类准确度。
表4评估网络σ消融研究实验结果
Claims (7)
1.一种基于元学习的图像分类方法,其特征在于,包括以下步骤:
S1、对于当前的包含Ni个样本的任务分别表示输入的当前任务的每一个样本及其标签,将当前任务样本分为训练样本以及测试样本其中,根据输入的训练样本训练样本特征提取器以及任务特征提取器F,并基于样本特征提取器以及任务特征提取器F获取当前任务的任务特征vi;其中,所述训练样本为图像;
S2、利用任务特征vi,更新对应的聚类特征向量集合并输出最适配当前任务的聚类特征向量cω;其中,K表示聚类层聚类的蔟数,cm表示聚类网络层中的聚类特征向量;
S3、基于任务特征评估网络σ对聚类特征向量cω及当前任务特征vi进行相关性评估,输出对应的评估系数α,并基于该评估系数α输出适用于当前任务的表征向量ωout;所述任务特征评估网络σ包括多个级联的连接层;
S4、基于任务的表征向量ωout对卷积神经网络模型的参数进行更新,得到最优的模型初始化参数θnew;
S5、基于所述初始化参数θnew训练并更新卷积神经网络模型参数,得到图像分类模型;
步骤S1中,当前任务特征vi的表达式为:
其中,gi,j为对样本特征向量进行编码得到的编码向量:
其中,RNNenc表示以当前任务样本特征向量作为输入的循环神经网络;训练期间,将训练样本输入到样本特征提取器 的输出即为对应的样本特征向量集合 为当前任务的第j个训练样本。
2.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S2的具体实现过程包括:
1)对比当前任务特征向量vi与所有存储的聚类特征向量集合的欧式距离d(cj,vi),并选出距离最小的聚类特征向量cω;
2)若步骤1)中聚类特征向量cω与当前任务的任务特征向量vi之间的距离d(cω,vi)大于设定的距离阈值dmax,同时,当前次训练对应的训练损失Lnow与上一次训练的训练损失Llast变化大于设定的变化阈值γ,即:d(cω,vi)>dmax and Lnow/Llast>γ,则聚类中心个数加1,并添加一个新的聚类特征向量cK+1,利用公式cK+1=vi初始化新的聚类特征向量,并将ck+1加入到聚类特征向量集合中,返回步骤1);若上述条件不满足,则利用公式N'ω=Nω+1获取更新的聚类特征向量cω',Nω代表当前被选择的聚类特征向量的更新次数,直接输出c'ω。
3.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S3中,适用于当前任务的表征向量ωout的表达式为:
ωout=α·cω+(1-α)·vi;
其中,σ表示由全连接网络构成的评估网络,表示将cω以及vi合并为一个向量并作为评估网络的输入参数。
4.根据权利要求1所述的基于元学习的图像分类方法,其特征在于,步骤S4中,其中,M表示模型参数的数量;gi为门控函数,由一组全连接网络层构成;φg表示门控函数的网络参数。
5.一种计算机装置,包括存储器、处理器及存储在存储器上的计算机程序;其特征在于,所述处理器执行所述计算机程序,以实现权利要求1~4之一所述方法的步骤。
6.一种计算机可读存储介质,其上存储有计算机程序/指令;其特征在于,所述计算机程序/指令被处理器执行时实现权利要求1~4之一所述方法的步骤。
7.一种计算机程序产品,包括计算机程序/指令;其特征在于,该计算机程序/指令被处理器执行时实现权利要求1~4之一所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111352723.4A CN114067155B (zh) | 2021-11-16 | 2021-11-16 | 基于元学习的图像分类方法、装置、产品及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111352723.4A CN114067155B (zh) | 2021-11-16 | 2021-11-16 | 基于元学习的图像分类方法、装置、产品及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114067155A CN114067155A (zh) | 2022-02-18 |
CN114067155B true CN114067155B (zh) | 2024-07-19 |
Family
ID=80272632
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111352723.4A Active CN114067155B (zh) | 2021-11-16 | 2021-11-16 | 基于元学习的图像分类方法、装置、产品及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114067155B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116563638B (zh) * | 2023-05-19 | 2023-12-05 | 广东石油化工学院 | 一种基于情景记忆的图像分类模型优化方法和系统 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111539448A (zh) * | 2020-03-17 | 2020-08-14 | 广东省智能制造研究所 | 一种基于元学习的少样本图像分类方法 |
WO2020249125A1 (zh) * | 2019-06-14 | 2020-12-17 | 第四范式(北京)技术有限公司 | 用于自动训练机器学习模型的方法和系统 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
GB2592076B (en) * | 2020-02-17 | 2022-09-07 | Huawei Tech Co Ltd | Method of training an image classification model |
-
2021
- 2021-11-16 CN CN202111352723.4A patent/CN114067155B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020249125A1 (zh) * | 2019-06-14 | 2020-12-17 | 第四范式(北京)技术有限公司 | 用于自动训练机器学习模型的方法和系统 |
CN111539448A (zh) * | 2020-03-17 | 2020-08-14 | 广东省智能制造研究所 | 一种基于元学习的少样本图像分类方法 |
Also Published As
Publication number | Publication date |
---|---|
CN114067155A (zh) | 2022-02-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113076994B (zh) | 一种开集域自适应图像分类方法及系统 | |
Mathur et al. | Crosspooled FishNet: transfer learning based fish species classification model | |
US8379994B2 (en) | Digital image analysis utilizing multiple human labels | |
WO2022062419A1 (zh) | 基于非督导金字塔相似性学习的目标重识别方法及系统 | |
CN110619059B (zh) | 一种基于迁移学习的建筑物标定方法 | |
CN113128478B (zh) | 模型训练方法、行人分析方法、装置、设备及存储介质 | |
WO2022218396A1 (zh) | 图像处理方法、装置和计算机可读存储介质 | |
CN109743642B (zh) | 基于分层循环神经网络的视频摘要生成方法 | |
CN106529570B (zh) | 基于深度脊波神经网络的图像分类方法 | |
CN115131618B (zh) | 基于因果推理的半监督图像分类方法 | |
CN113761259A (zh) | 一种图像处理方法、装置以及计算机设备 | |
CN110569780A (zh) | 一种基于深度迁移学习的高精度人脸识别方法 | |
CN111127360A (zh) | 一种基于自动编码器的灰度图像迁移学习方法 | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN109376736A (zh) | 一种基于深度卷积神经网络的视频小目标检测方法 | |
CN112115996B (zh) | 图像数据的处理方法、装置、设备及存储介质 | |
CN109829414A (zh) | 一种基于标签不确定性和人体组件模型的行人再识别方法 | |
CN111126155B (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
CN111079837A (zh) | 一种用于二维灰度图像检测识别分类的方法 | |
CN114821237A (zh) | 一种基于多级对比学习的无监督船舶再识别方法及系统 | |
CN114329031A (zh) | 一种基于图神经网络和深度哈希的细粒度鸟类图像检索方法 | |
CN111709442A (zh) | 一种面向图像分类任务的多层字典学习方法 | |
CN114067155B (zh) | 基于元学习的图像分类方法、装置、产品及存储介质 | |
CN111310820A (zh) | 基于交叉验证深度cnn特征集成的地基气象云图分类方法 | |
CN117079017A (zh) | 可信的小样本图像识别分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |