CN116109901A - 自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质 - Google Patents
自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质 Download PDFInfo
- Publication number
- CN116109901A CN116109901A CN202310133742.0A CN202310133742A CN116109901A CN 116109901 A CN116109901 A CN 116109901A CN 202310133742 A CN202310133742 A CN 202310133742A CN 116109901 A CN116109901 A CN 116109901A
- Authority
- CN
- China
- Prior art keywords
- task
- learner
- module
- small sample
- meta
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Multimedia (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Electrically Operated Instructional Devices (AREA)
- Machine Translation (AREA)
Abstract
本发明提供一种自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质,通过在小样本任务集合上结合元学习器和调制模块初始化任务学习器,并在每个任务的支持集上基于自适应学习率模块和高斯动量丢失模块更新任务学习器,再基于训练后的任务学习器对查询集样本进行识别并根据识别结果训练元学习器、调制模块及自适应学习率模块;本发明利用任务特定知识,为任务学习器自适应的生成任务特定的初始化参数和学习率张量,并且基于高斯动量丢失模块在任务学习器更新时添加正则项以缓解过拟合现象,提升了元学习模型在小样本任务的性能。
Description
技术领域
本发明涉及机器学习小样本学习领域,特别是涉及一种自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质。
背景技术
元学习(Meta-Learning,ML)是小样本学习问题的有效解决方案,它能够自动学习跨任务的知识。元学习器从每个训练任务中获取先验知识,引导学习器适应新的任务,即“learning to learn”。更具体地说,基于优化的ML可以看作是双层优化过程,由两层循环构成。在任务适应层面上(内循环),任务学习器需要用特定任务的知识快速适应当前任务;在元适应层面上(外循环),元学习器需要缓慢地学习跨任务的知识,然后将知识反馈给任务学习器。经过元训练阶段得到的模型具有良好的任务泛化特性,即模型只需简单的几个优化步骤就能快速适应目标任务,在目标任务上取得良好的分类精度。其中,一个经典的方法是模型无关元学习(Model-Agnostic Meta-Learning,MAML),在MAML中元学习器为任务学习器提供了一个具有良好任务泛化性能的、任务间共享的初始化参数。
在MAML的基础上,一些梯度预处理方法被提出,其中扭曲梯度下降方法(WarpedGradient Descent,WarpGrad)在小样本学习问题中显示了优越的效果。在任务适应过程中反向传播计算梯度时,WarpGrad向模型网络中添加了扭曲层用于对梯度进行非线性预处理,其中扭曲层的参数作为元参数的一部分与任务共享的初始化参数一起在外部循环中获得,从而能够有效捕获损和利用损失面的跨任务几何信息。此外,一些方法,如元随机梯度下降(Meta-SGD)和元曲率(Meta-Curvature,MC),通过跨任务元训练自适应地搜索学习率张量,基于此对梯度进行线性预处理,这里学习率张量就是线性梯度预处理参数。
目前基于梯度预处理的元学习方法都只关注任务共享知识,忽视了任务特定知识的利用,即在任务适应过程中没有对共享的初始化参数和梯度预处理参数进行更新以适应当前任务。此外,梯度预处理增加了参数数量,带来了显著的过拟合风险。
发明内容
鉴于以上所述现有技术的缺点,本发明的目的在于提供一种自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质,用于解决以上现有技术问题。
为实现上述目的及其他相关目的,本发明提供一种自适应正则化扭曲梯度下降的小样本元学习方法,所述方法包括:基于小样本任务集合中每个任务的支持集,利用元学习器以及调制模块初始化对应各任务的任务学习器;其中,每个支持集中包括:多个支持集样本;利用自适应学习率模块和高斯动量丢失模块获取每个任务学习器的优化步长和优化方向,以对初始化的各任务学习器进行一或多次更新;利用更新后的各任务学习器对所述小样本任务集合中每个任务的查询集进行分类,并根据分类结果更新所述元学习器、调制模块以及自适应学习率模块;其中,每个查询集中包括:多个查询集样本。
于本发明的一实施例中,所述基于小样本任务集合中每个任务的支持集,利用元学习器以及调制模块初始化对应各任务的任务学习器包括:利用所述元学习器基于对应的任务间共享初始化参数对每个支持集的支持集样本进行分类,并计算在每个任务上的分类损失和梯度;基于所述调制模块,根据计算获得的分类损失和梯度生成每个任务的任务学习器的调制参数,以获得每个任务任务特定的初始化参数;利用各初始化参数,初始化各任务学习器。
于本发明的一实施例中,所述根据计算获得的分类损失和梯度生成每个任务的任务学习器的调制参数,以获得每个任务的任务特定的初始化参数包括:基于输入的在每个任务上的分类损失和梯度生成每个任务的任务学习器的调制参数;利用各调制参数对对应的任务间共享初始化参数进行调整,生成每个任务任务特定的初始化参数。
于本发明的一实施例中,所述利用自适应学习率模块和高斯动量丢失模块获取每个任务学习器的优化步长和优化方向,以对初始化的各任务学习器进行一或多次更新包括:利用自适应学习率模块和高斯动量丢失模块执行一或次任务学习器更新流程,以对初始化的各任务学习器进行一或多次更新;其中,所述任务学习器更新流程包括:利用初始化的各任务学习器分别对对应的支持集的支持集样本进行分类,并计算在每个任务上的分类损失和扭曲梯度;基于自适应学习率模块,根据每个任务上的分类损失和扭曲梯度生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;基于高斯动量丢失模块,根据每个任务上的分类损失和扭曲梯度获取动量并向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行正则化;基于各任务的优化步长以及优化方向对初始化的各任务学习器进行更新。
于本发明的一实施例中,所述利用更新后的各任务学习器对所述小样本任务集合中每个任务的查询集进行分类,并根据分类结果更新所述元学习器、调制模块以及自适应学习率模块包括:基于更新后的各任务学习器对每个任务的查询集的查询样本进行分类并计算分类损失;基于每个任务学习器所对应的分类损失进行梯度反向传播,并更新所述元学习器、调制模块、自适应学习率模块。
于本发明的一实施例中,所述基于每个任务学习器所对应的分类损失进行梯度反向传播,并更新所述元学习器、调制模块、自适应学习率模块包括:基于每个任务学习器所对应的分类损失,计算用于元模型更新的元损失;基于所述元损失计算用于元参数更新的元梯度,并根据SGD对元学习器的元学习器参数、调制模块的调制模块参数以及自适应学习率模块的自适应学习率模块参数进行更新。
于本发明的一实施例中,所述任务学习器以及元学习器分别采用4层卷积网络,所述调制模块采用多层感知机以及所述自适应学习率模块采用LSTM网络。
为实现上述目的及其他相关目的,本发明提供一种自适应正则化扭曲梯度下降的小样本元学习系统,所述系统包括:所述系统包括:小样本任务集合模块、每个任务对应的任务学习器、元学习器、调制模块、自适应学习率模块以及高斯动量丢失模块;其中,所述小样本任务集合模块,用于储存小样本任务集合中每个任务的支持集以及查询集;其中,每个支持集用于训练每个任务的任务学习器;每个查询集用于训练所述元学习器、调制模块和自适应学习率模块;所述元学习器,用于学习提取小样本任务集合中支持集的支持集样本的样本特征并分类,并将其对应的元学习器参数分别作为各任务学习器的共享初始化参数;所述调制模块,用于根据所述元学习器在每个任务的支持集上的分类结果获取调制参数,以生成更适合每个任务任务特定的任务学习器初始化参数;每个任务学习器,用于对输入的对应任务的样本进行特征提取和分类;所述自适应学习率模块,用于捕获每个任务上的局部损失面信息,并生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;其中,所述局部损失面信息包括:分类损失和扭曲梯度;所述高斯动量丢失模块,用于根据每个任务上的局部损失面信息,向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行正则化。
为实现上述目的及其他相关目的,本发明提供一种自适应正则化扭曲梯度下降的小样本元学习终端,包括:一或多个存储器及一或多个处理器;所述一或多个存储器,用于存储计算机程序;所述一或多个处理器,连接所述存储器,用于运行所述计算机程序以执行所述自适应正则化扭曲梯度下降的小样本元学习方法。
为实现上述目的及其他相关目的,本发明提供一种计算机可读存储介质,存储有计算机程序,所述计算机程序被一个或多个处理器运行时执行所述自适应正则化扭曲梯度下降的小样本元学习方法。
如上所述,本发明是一种自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质,具有以下有益效果:本发明通过在小样本任务集合上结合元学习器和调制模块初始化任务学习器,并在每个任务的支持集上基于自适应学习率模块和高斯动量丢失模块更新任务学习器,基于训练后的任务学习器对查询集样本进行识别并根据识别结果训练元学习器、调制模块及自适应学习率模块;本发明利用任务特定知识,为任务学习器自适应的生成任务特定的初始化参数和学习率张量,并且基于高斯动量丢失模块在任务学习器更新时添加正则项以缓解过拟合现象,提升了元学习模型在小样本任务的性能。
附图说明
图1显示为本发明一实施例中的自适应正则化扭曲梯度下降的小样本元学习方法的流程示意图。
图2显示为本发明一实施例中的自适应正则化扭曲梯度下降的小样本元学习方法的流程示意图。
图3显示为本发明一实施例中的自适应正则化扭曲梯度下降的小样本元学习系统的结构示意图。
图4显示为本发明一实施例中的自适应正则化扭曲梯度下降的小样本元学习系统的结构示意图。
图5显示为本发明一实施例中的自适应正则化扭曲梯度下降的小样本元学习终端的结构示意图。
具体实施方式
以下通过特定的具体实例说明本发明的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本发明的其他优点与功效。本发明还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本发明的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。
需要说明的是,在下述描述中,参考附图,附图描述了本发明的若干实施例。应当理解,还可使用其他实施例,并且可以在不背离本发明的精神和范围的情况下进行机械组成、结构、电气以及操作上的改变。下面的详细描述不应该被认为是限制性的,并且本发明的实施例的范围仅由公布的专利的权利要求书所限定。这里使用的术语仅是为了描述特定实施例,而并非旨在限制本发明。空间相关的术语,例如“上”、“下”、“左”、“右”、“下面”、“下方”、““下部”、“上方”、“上部”等,可在文中使用以便于说明图中所示的一个元件或特征与另一元件或特征的关系。
在通篇说明书中,当说某部分与另一部分“连接”时,这不仅包括“直接连接”的情形,也包括在其中间把其它元件置于其间而“间接连接”的情形。另外,当说某种部分“包括”某种构成要素时,只要没有特别相反的记载,则并非将其它构成要素,排除在外,而是意味着可以还包括其它构成要素。
其中提到的第一、第二及第三等术语是为了说明多样的部分、成分、区域、层及/或段而使用的,但并非限定于此。这些术语只用于把某部分、成分、区域、层或段区别于其它部分、成分、区域、层或段。因此,以下叙述的第一部分、成分、区域、层或段在不超出本发明范围的范围内,可以言及到第二部分、成分、区域、层或段。
再者,如同在本文中所使用的,单数形式“一”、“一个”和“该”旨在也包括复数形式,除非上下文中有相反的指示。应当进一步理解,术语“包含”、“包括”表明存在所述的特征、操作、元件、组件、项目、种类、和/或组,但不排除一个或多个其他特征、操作、元件、组件、项目、种类、和/或组的存在、出现或添加。此处使用的术语“或”和“和/或”被解释为包括性的,或意味着任一个或任何组合。因此,“A、B或C”或者“A、B和/或C”意味着“以下任一个:A;B;C;A和B;A和C;B和C;A、B和C”。仅当元件、功能或操作的组合在某些方式下内在地互相排斥时,才会出现该定义的例外。
本发明的一种自适应正则化扭曲梯度下降的小样本元学习方法,通过在小样本任务集合上结合元学习器和调制模块初始化任务学习器,并在每个任务的支持集上基于自适应学习率模块和高斯动量丢失模块更新任务学习器,基于训练后的任务学习器对查询集样本进行识别并根据识别结果训练元学习器、调制模块及自适应学习率模块;本发明利用任务特定知识,为任务学习器自适应的生成任务特定的初始化参数和学习率张量,并且基于高斯动量丢失模块在任务学习器更新时添加正则项以缓解过拟合现象,提升了元学习模型在小样本任务的性能。
下面以附图为参考,针对本发明的实施例进行详细说明,以便本发明所述技术领域的技术人员能够容易地实施。本发明可以以多种不同形态体现,并不限于此处说明的实施例。
如图1展示本发明实施例中的一种自适应正则化扭曲梯度下降的小样本元学习方法的结构示意图。
所述方法包括:
步骤S101:基于小样本任务集合中每个任务的支持集,利用元学习器以及调制模块初始化对应各任务的任务学习器。
详细来说,所述小样本任务集合中每个任务都包含支持集和查询集,支持集用于训练任务学习器,查询集用于训练元学习器、调制模块和自适应学习率模块。其中,每个支持集中包括:多个支持集样本;每个查询集中包括:多个查询集样本。
在一实施例中,所述元学习器用于学习提取小样本任务集合中支持集的支持集样本的样本特征并分类,并将其对应的元学习器参数分别作为各任务学习器的共享初始化参数;所述调制模块用于根据所述元学习器在每个任务的支持集上的分类结果获取调制参数,以生成更适合每个任务任务特定的任务学习器初始化参数。
结合以上描述,所述基于小样本任务集合中每个任务的支持集,利用元学习器以及调制模块初始化对应各任务的任务学习器包括:
利用所述元学习器基于对应的任务间共享初始化参数对每个支持集的支持集样本进行分类,并计算在每个任务上的分类损失和梯度,以获得反映元学习器在当前任务上性能的梯度信息;
基于所述调制模块,根据计算获得的分类损失和梯度生成每个任务的任务学习器的调制参数,以获得每个任务任务特定的初始化参数;
利用各初始化参数,初始化各任务学习器。
在一实施例中,所述根据计算获得的分类损失和梯度生成每个任务的任务学习器的调制参数,以获得每个任务的任务特定的初始化参数包括:
基于输入的在每个任务上的分类损失和梯度生成每个任务的任务学习器的调制参数;
利用各调制参数对对应的任务间共享初始化参数进行调整,生成每个任务任务特定的初始化参数。
在一具体实施例中,基于任务间共享初始化参数θ计算其在每个任务Ti上的损失和梯度,并将其输入到由v参数化的调制网络πv生成调制参数vi:其中,调制参数vi的计算公式如下:
根据所述生成的调制参数对任务间共享初始化参数进行调制生成任务特定的初始化参数,并初始化任务学习器,该过程可以描述为:
其中,k表示基于支持集对任务学习器的更新次数,故此时k=0,且θi,k就为任务Ti的任务学习器初始化参数,在任务适应过程中将对θi,k和k进行更新。
在一实施例中,所述任务学习器的预设工具为4层卷积网络,述元学习器的预设工具为4层卷积网络,所述调制模块的预设工具为多层感知机,所述自适应学习率模块的预设工具为LSTM,所述分类损失的预设工具为交叉熵分类损失。
步骤S102:利用自适应学习率模块和高斯动量丢失模块获取每个任务学习器的优化步长和优化方向,以对初始化的各任务学习器进行一或多次更新。
在一实施例中,所述自适应学习率模块用于捕获每个任务上的局部损失面信息,并生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;所述高斯动量丢失模块用于根据每个任务上的局部损失面信息,向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行对应正则化。
结合以上内容,步骤S102包括:
利用自适应学习率模块和高斯动量丢失模块执行一或次任务学习器更新流程,以对初始化的各任务学习器进行一或多次更新;
其中,所述任务学习器更新流程包括:
利用初始化的各任务学习器分别对对应的支持集的支持集样本进行分类,并计算在每个任务上的分类损失和扭曲梯度;
基于自适应学习率模块,根据每个任务上的分类损失和扭曲梯度生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;
基于高斯动量丢失模块,根据每个任务上的分类损失和扭曲梯度获取动量并向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行正则化;
基于各任务的优化步长以及优化方向对初始化的各任务学习器进行更新,以使任务学习器适应当前任务。
需要说明的是,对初始化的各任务学习器进行更新的次数可根据需求而设定。
在一具体实施例中,基于所述任务学习器θi,k对支持集train(Ti)中样本进行分类,计算对应的分类损失和扭曲梯度gi,k,其中计算公式为:
基于自适应学习率模块生成任务自适应学习率张量αi,k的过程可以描述为:
基于LSTM网络对任务局部损失面几何信息的捕获,获得的任务自适应学习率张量αi,k能够有效帮助任务学习器更快的适应到当前任务上。
基于高斯动量丢失模块生成正则化后的扭曲梯度作为优化方向,基于所述得到的梯度,动量的更新公式为:
mi,k=lmi,k-1+(1-l)gi,k; (7)
其中,l是上一步的动量的权重,预设值为0.95。基于动量丢失概率p,从预定义的高斯分布上采样噪声正则化项n:
其中,p的预设值为0.2。基于所述得到的动量mi,k和丢失因子n即可得到正则化后的优化方向:
m′i,k=mi,ken; (9)
其中,e是逐元素乘积算子。
基于所述得到的自适应学习率张量αi,k和优化方向m′i,k更新任务学习器,该过程可以描述为:
θi,k=θi,k-1-αi,kem′i,k; (10)
所述任务学习器的更新过程中,自适应学习率张量αi,k用于基于任务特定知识对梯度进行了线性预处理,优化方向m′i,k通过引入高斯噪声有效地缓解了过拟合现象。
步骤S103:利用更新后的各任务学习器对所述小样本任务集合中每个任务的查询集进行分类,并根据分类结果更新所述元学习器、调制模块以及自适应学习率模块。
详细来说,每个任务学习器用于对输入的对应任务的样本进行特征提取和分类;每个查询集中包括:多个查询集样本。
在一实施例中,步骤S103包括:
基于更新后的各任务学习器对每个任务的查询集的查询样本进行分类并计算分类损失;
基于每个任务学习器所对应的分类损失进行梯度反向传播,并更新所述元学习器、调制模块、自适应学习率模块。
在一实施例中,所述基于每个任务学习器所对应的分类损失进行梯度反向传播,并更新所述元学习器、调制模块、自适应学习率模块包括:
基于每个任务学习器所对应的分类损失,计算用于元模型更新的元损失;
基于所述元损失计算用于元参数更新的元梯度,并根据SGD(随机梯度下降算法)对元学习器的元学习器参数、调制模块的调制模块参数以及自适应学习率模块的自适应学习率模块参数进行更新。
在具体的实施例中,基于得到的任务适应后的任务学习器参数θi,K,对任务Ti中所有查询集图像进行分类,根据真实标签计算在查询集样本上的分类损失对小样本任务集合中每个任务分别执行上述步骤,得到在每个任务查询集上的损失计算用于元模型更新的元损失:
其中,β为外环学习率,预设值为0.001。
为了更好的说明上述自适应正则化扭曲梯度下降的小样本元学习方法,本发明提供以下具体实施例。
实施例1:一种自适应正则化扭曲梯度下降的小样本元学习方法。图2为自适应正则化扭曲梯度下降的小样本元学习方法的流程示意图。
本发明实施例基于元学习方法包含以下步骤:
步骤S1:获取小样本训练数据集和测试数据集。
步骤S3:基于任务间共享初始化参数对支持集样本进行分类并计算分类损失和梯度,将其输入到调制模块中生成任务特定的初始化参数并初始化任务学习器。其具体步骤包括:
步骤S3.1:基于任务间共享的初始化参数θ计算其在每个任务Ti上的损失和梯度,并将其输入到由v参数化的调制网络πv生成调制参数vi:
步骤S3.2:根据所述生成的调制参数对任务间共享的初始化参数进行调制生成任务特定的初始化参数,并初始化任务学习器,该过程可以描述为:
其中,k表示基于支持集对任务学习器的更新次数,故此时k=0,且θi,k就为任务Ti的任务学习器初始化参数,在任务适应过程中将对θi,k和k进行更新。
步骤S4:基于所述任务学习器对支持集样本进行分类,计算对应的分类损失和扭曲梯度。分别通过自适应学习率模块和高斯动量丢失模块生成任务自适应的学习率张量和正则化后的扭曲梯度作为优化方向,基于所述得到的学习率和优化方向更新任务学习器。重复上述任务学习器更新过程以适应每个任务。其具体步骤包括:
步骤S4.1:基于所述任务学习器θi,k对支持集train(Ti)中样本进行分类,计算对应的分类损失和扭曲梯度gi,k:
步骤S4.2:通过自适应学习率模块生成任务自适应的学习率张量,高斯动量丢失模块生成正则化后的扭曲梯度作为优化方向,基于所述得到的学习率和优化方向更新任务学习器。其具体步骤包括:
步骤S4.2.1:基于自适应学习率模块生成任务自适应学习率张量αi,k的过程可以描述为:
其中,基于LSTM网络对任务局部损失面几何信息的捕获,所述获得的任务自适应学习率张量αi,k能够有效帮助任务学习器更快的适应到当前任务上。
步骤S4.2.2:基于高斯动量丢失模块生成正则化后的扭曲梯度作为优化方向。基于所述得到的梯度,动量的更新公式为:
mi,k=lmi,k-1+(1-l)gi,k; (7)
其中,l是上一步的动量的权重,预设值为0.95。基于动量丢失概率p,从预定义的高斯分布上采样噪声正则化项n:
其中,p的预设值为0.2。基于所述得到的动量mi,k和丢失因子n即可得到正则化后的优化方向:
m′i,k=mi,ken; (9)
其中,e是逐元素乘积算子。
步骤S4.2.3:基于所述得到的自适应学习率张量αi,k和优化方向m′i,k更新任务学习器,该过程可以描述为:
θi,k=θi,k-1-αi,kem′i,k (10)
所述任务学习器的更新过程中,自适应学习率张量αi,k用于基于任务特定知识对梯度进行了线性预处理,优化方向m′i,k通过引入高斯噪声有效地缓解了过拟合现象。
步骤S4.3,k=k+1,重复步骤S4.1至S4.2K次,直至k=K。
步骤S5:基于所述更新后的任务学习器对查询集样本进行分类,根据分类结果计算交叉熵分类损失,基于该分类损失更新元学习器参数。其具体步骤包括:
步骤S5.1:基于S4所述得到的任务适应后的任务学习器参数θi,K,对任务Ti中所有查询集图像进行分类,根据真实标签计算在查询集样本上的分类损失对小样本任务集合中每个任务分别执行上述步骤,得到在每个任务查询集上的损失计算用于元模型更新的元损失:
其中,β为外环学习率,预设值为0.001。
步骤S6:重复步骤S2至S5,直到元模型的参数收敛。
步骤S7:从小样本测试数据集上采样目标小样本任务集合,进行S3-S4步骤获取更新后的任务学习器参数,基于所述任务学习器对目标任务查询集样本进行分类。
本实施例基于小样本任务集合,通过调制模块对元学习器进行任务定制生成任务学习器,并在每个任务的支持集上基于自适应学习率模块和高斯动量丢失模块更新任务学习器。即在本实施例中,元学习器的参数不直接用于初始化任务学习器,而是通过调制模块根据元学习器在任务支持集上的性能生成调制参数,并对元学习器参数进行调制以初始化任务学习器,并且在任务学习器的更新过程中,通过自适应学习率张量进行了线性梯度预处理,通过引入高斯噪声进行了优化方向的正则化,因此任务学习器及其更新过程能够随着任务进行自适应调整,从而为任务适应过程提供了更好的指导,提高了小样本元学习的性能,模型的任务泛化性更强。
与上述实施例原理相似的是,本发明提供一种自适应正则化扭曲梯度下降的小样本元学习系统。
以下结合附图提供具体实施例:
如图3展示本发明实施例中的一种自适应正则化扭曲梯度下降的小样本元学习系统的结构示意图。
所述系统包括:小样本任务集合模块1、每个任务对应的任务学习器2、元学习器3、调制模块4、自适应学习率模块5以及高斯动量丢失模块6;需要说明的是,图中仅以一个任务学习器为例。
其中,所述小样本任务集合模块1,用于储存小样本任务集合中每个任务的支持集以及查询集;其中,每个支持集用于训练每个任务的任务学习器;每个查询集用于训练所述元学习器、调制模块和自适应学习率模块;
所述元学习器3,用于学习提取小样本任务集合中支持集的支持集样本的样本特征并分类,并将其对应的元学习器参数分别作为各任务学习器的共享初始化参数;
所述调制模块4,用于根据所述元学习器在每个任务的支持集上的分类结果获取调制参数,以生成更适合每个任务任务特定的任务学习器初始化参数;
每个任务学习器2,用于对输入的对应任务的样本进行特征提取和分类;
所述自适应学习率模块5,用于捕获每个任务上的局部损失面信息,并生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;其中,所述局部损失面信息包括:分类损失和扭曲梯度;
所述高斯动量丢失模块6,用于根据每个任务上的局部损失面信息,向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行正则化。
其中,采用所述自适应正则化扭曲梯度下降的小样本元学习系统进行小样本学习的方法包括:
步骤一:基于小样本任务集合模块1中的每个任务的支持集,利用元学习器3以及调制模块4初始化对应各任务的任务学习器2;
具体的步骤包括:从小样本任务集合模块1获取小样本任务集合中每个任务的支持集;利用所述元学习器3基于对应的任务间共享初始化参数对每个支持集的支持集样本进行分类,并计算在每个任务上的分类损失和梯度;基于所述调制模块4,根据计算获得的分类损失和梯度生成每个任务的任务学习器2的调制参数,以获得每个任务任务特定的初始化参数,并利用各初始化参数,初始化各任务学习器。
步骤二:利用自适应学习率模块5和高斯动量丢失模块6获取每个任务学习器2的优化步长和优化方向,以对初始化的各任务学习器2进行一或多次更新。
具体的步骤包括:利用初始化的各任务学习器2分别对对应的支持集的支持集样本进行分类,并计算在每个任务上的分类损失和扭曲梯度;基于自适应学习率模块5,根据每个任务上的分类损失和扭曲梯度生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;基于高斯动量丢失模块6,根据每个任务上的分类损失和扭曲梯度获取动量并向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器2的更新过程进行正则化;基于各任务的优化步长以及优化方向对初始化的各任务学习器进行更新。
步骤三:利用更新后的各任务学习器2对所述小样本任务集合模块1中每个任务的查询集进行分类,并根据分类结果更新所述元学习器3、调制模块4以及自适应学习率模块5。
具体的步骤包括:当重复更新步骤多次后,利用更新后的各任务学习器2对所述小样本任务集合中每个任务的查询集进行分类,并根据分类结果更新所述元学习器3、调制模块4以及自适应学习率模块5。
由于该自适应正则化扭曲梯度下降的小样本元学习系统的实现原理已在前述实施例中进行了叙述,因此此处不作重复赘述。
为了更好的说明上述自适应正则化扭曲梯度下降的小样本元学习系统,本发明提供以下具体实施例。
实施例1:一种自适应正则化扭曲梯度下降的小样本元学习系统。图4为自适应正则化扭曲梯度下降的小样本元学习系统的结构示意图。
自适应正则化扭曲梯度下降的小样本元学习系统装置包括:小样本任务集合,元学习器,调制模块,任务学习器,自适应学习率模块和高斯动量丢失模块;
其中,所述小样本任务集合中每个任务都包含支持集和查询集,支持集用于训练任务学习器,查询集用于训练元学习器、调制模块和自适应学习率模块。所述元学习器,用于学习提取小样本任务集合中样本特征并分类,以作为任务学习器的共享初始化参数;所述调制模块,用于根据元学习器在当前任务支持集上的分类结果获取调制参数,以生成更适合当前任务特定的任务学习器初始化参数;所述任务学习器,用于学习当前任务的先验知识,以对输入任务中的样本进行特征提取和分类;所述自适应学习率模块,用于捕获任务局部损失面信息,以生成更适合当前任务适应过程的学习率张量用于梯度线性预处理;所述高斯动量丢失模块,用于向动量中引入服从高斯分布的噪声,以对任务学习器的更新过程进行正则化,有效缓解过拟合现象。
如图5展示本发明实施例中的自适应正则化扭曲梯度下降的小样本元学习终端50的结构示意图。
所述自适应正则化扭曲梯度下降的小样本元学习终端50包括:存储器51及处理器52。所述存储器51用于存储计算机程序;所述处理器52运行计算机程序,实现如图1所述的自适应正则化扭曲梯度下降的小样本元学习方法。
可选的,所述存储器51的数量均可以是一或多个,所述处理器52的数量均可以是一或多个,而图5中均以一个为例。
可选的,所述自适应正则化扭曲梯度下降的小样本元学习终端50中的处理器52会按照如图1所述的步骤,将一个或多个以应用程序的进程对应的指令加载到存储器51中,并由处理器52来运行存储在第一存储器51中的应用程序,从而实现如图1所述自适应正则化扭曲梯度下降的小样本元学习方法中的各种功能。
可选的,所述存储器51,可能包括但不限于高速随机存取存储器、非易失性存储器。例如一个或多个磁盘存储设备、闪存设备或其他非易失性固态存储设备;所述处理器52,可能包括但不限于中央处理器(Central Processing Unit,简称CPU)、网络处理器(Network Processor,简称NP)等;还可以是数字信号处理器(Digital SignalProcessing,简称DSP)、专用集成电路(Application Specific Integrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
可选的,所述处理器52可以是通用处理器,包括中央处理器(Central ProcessingUnit,简称CPU)、网络处理器(Network Processor,简称NP)等;还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application SpecificIntegrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
本发明还提供计算机可读存储介质,存储有计算机程序,所述计算机程序运行时实现如图1所示的自适应正则化扭曲梯度下降的小样本元学习方法。所述计算机可读存储介质可包括,但不限于,软盘、光盘、CD-ROM(只读光盘存储器)、磁光盘、ROM(只读存储器)、RAM(随机存取存储器)、EPROM(可擦除可编程只读存储器)、EEPROM(电可擦除可编程只读存储器)、磁卡或光卡、闪存、或适于存储机器可执行指令的其他类型的介质/机器可读介质。所述计算机可读存储介质可以是未接入计算机设备的产品,也可以是已接入计算机设备使用的部件。
综上所述,本发明的自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质,通过在小样本任务集合上结合元学习器和调制模块初始化任务学习器,并在每个任务的支持集上基于自适应学习率模块和高斯动量丢失模块更新任务学习器,基于训练后的任务学习器对查询集样本进行识别并根据识别结果训练元学习器、调制模块及自适应学习率模块;本发明利用任务特定知识,为任务学习器自适应的生成任务特定的初始化参数和学习率张量,并且基于高斯动量丢失模块在任务学习器更新时添加正则项以缓解过拟合现象,提升了元学习模型在小样本任务的性能。所以,本发明有效克服了现有技术中的种种缺点而具高度产业利用价值。
上述实施例仅示例性说明本发明的原理及其功效,而非用于限制本发明。任何熟悉此技术的人士皆可在不违背本发明的精神及范畴下,对上述实施例进行修饰或改变。因此,但凡所属技术领域中具有通常知识者在未脱离本发明所揭示的精神与技术思想下所完成的一切等效修饰或改变,仍应由本发明的权利要求所涵盖。
Claims (10)
1.一种自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述方法包括:
基于小样本任务集合中每个任务的支持集,利用元学习器以及调制模块初始化对应各任务的任务学习器;其中,每个支持集中包括:多个支持集样本;
利用自适应学习率模块和高斯动量丢失模块获取每个任务学习器的优化步长和优化方向,以对初始化的各任务学习器进行一或多次更新;
利用更新后的各任务学习器对所述小样本任务集合中每个任务的查询集进行分类,并根据分类结果更新所述元学习器、调制模块以及自适应学习率模块;其中,每个查询集中包括:多个查询集样本。
2.根据权利要求1中所述的自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述基于小样本任务集合中每个任务的支持集,利用元学习器以及调制模块初始化对应各任务的任务学习器包括:
利用所述元学习器基于对应的任务间共享初始化参数对每个支持集的支持集样本进行分类,并计算在每个任务上的分类损失和梯度;
基于所述调制模块,根据计算获得的分类损失和梯度生成每个任务的任务学习器的调制参数,以获得每个任务任务特定的初始化参数;
利用各初始化参数,初始化各任务学习器。
3.根据权利要求2中所述的自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述根据计算获得的分类损失和梯度生成每个任务的任务学习器的调制参数,以获得每个任务的任务特定的初始化参数包括:
基于输入的在每个任务上的分类损失和梯度生成每个任务的任务学习器的调制参数;
利用各调制参数对对应的任务间共享初始化参数进行调整,生成每个任务任务特定的初始化参数。
4.根据权利要求1中所述的自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述利用自适应学习率模块和高斯动量丢失模块获取每个任务学习器的优化步长和优化方向,以对初始化的各任务学习器进行一或多次更新包括:
利用自适应学习率模块和高斯动量丢失模块执行一或次任务学习器更新流程,以对初始化的各任务学习器进行一或多次更新;
其中,所述任务学习器更新流程包括:
利用初始化的各任务学习器分别对对应任务的支持集的支持集样本进行分类,并计算在每个任务上的分类损失和扭曲梯度;
基于自适应学习率模块,根据每个任务上的分类损失和扭曲梯度生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;
基于高斯动量丢失模块,根据每个任务上的分类损失和扭曲梯度获取动量并向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行正则化;
基于各任务的优化步长以及优化方向对初始化的各任务学习器进行更新。
5.根据权利要求1中所述的自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述利用更新后的各任务学习器对所述小样本任务集合中每个任务的查询集进行分类,并根据分类结果更新所述元学习器、调制模块以及自适应学习率模块包括:
基于更新后的各任务学习器对每个任务的查询集的查询样本进行分类并计算分类损失;
基于每个任务学习器所对应的分类损失进行梯度反向传播,并更新所述元学习器、调制模块、自适应学习率模块。
6.根据权利要求5中所述的自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述基于每个任务学习器所对应的分类损失进行梯度反向传播,并更新所述元学习器、调制模块、自适应学习率模块包括:
基于每个任务学习器所对应的分类损失,计算用于元模型更新的元损失;
基于所述元损失计算用于元参数更新的元梯度,并根据SGD对元学习器的元学习器参数、调制模块的调制模块参数以及自适应学习率模块的自适应学习率模块参数进行更新。
7.根据权利要求1中所述的自适应正则化扭曲梯度下降的小样本元学习方法,其特征在于,所述任务学习器以及元学习器分别采用4层卷积网络,所述调制模块采用多层感知机以及所述自适应学习率模块采用LSTM网络。
8.一种自适应正则化扭曲梯度下降的小样本元学习系统,其特征在于,所述系统包括:
小样本任务集合模块、每个任务对应的任务学习器、元学习器、调制模块、自适应学习率模块以及高斯动量丢失模块;
其中,所述小样本任务集合模块,用于储存小样本任务集合中每个任务的支持集以及查询集;其中,每个支持集用于训练每个任务的任务学习器;每个查询集用于训练所述元学习器、调制模块和自适应学习率模块;
所述元学习器,用于学习提取小样本任务集合中支持集的支持集样本的样本特征并分类,并将其对应的元学习器参数分别作为各任务学习器的共享初始化参数;
所述调制模块,用于根据所述元学习器在每个任务的支持集上的分类结果获取调制参数,以生成更适合每个任务任务特定的任务学习器初始化参数;
每个任务学习器,用于对输入的对应任务的样本进行特征提取和分类;
所述自适应学习率模块,用于捕获每个任务上的局部损失面信息,并生成作为各任务优化步长的自适应学习率张量,以进行各任务任务特定的线性梯度预处理;其中,所述局部损失面信息包括:分类损失和扭曲梯度;
所述高斯动量丢失模块,用于根据每个任务上的局部损失面信息,向动量中引入服从高斯分布的噪声,生成作为各任务优化方向的正则化后的扭曲梯度,以对各任务学习器的更新过程进行正则化。
9.一种自适应正则化扭曲梯度下降的小样本元学习终端,其特征在于,包括:一或多个存储器及一或多个处理器;
所述一或多个存储器,用于存储计算机程序;
所述一或多个处理器,连接所述存储器,用于运行所述计算机程序以执行如权利要求1至7所述的方法。
10.一种计算机可读存储介质,其特征在于,存储有计算机程序,所述计算机程序被一个或多个处理器运行时执行如权利要求1至7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310133742.0A CN116109901A (zh) | 2023-02-17 | 2023-02-17 | 自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310133742.0A CN116109901A (zh) | 2023-02-17 | 2023-02-17 | 自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116109901A true CN116109901A (zh) | 2023-05-12 |
Family
ID=86265354
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310133742.0A Pending CN116109901A (zh) | 2023-02-17 | 2023-02-17 | 自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116109901A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116737939A (zh) * | 2023-08-09 | 2023-09-12 | 恒生电子股份有限公司 | 元学习方法、文本分类方法、装置、电子设备及存储介质 |
-
2023
- 2023-02-17 CN CN202310133742.0A patent/CN116109901A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116737939A (zh) * | 2023-08-09 | 2023-09-12 | 恒生电子股份有限公司 | 元学习方法、文本分类方法、装置、电子设备及存储介质 |
CN116737939B (zh) * | 2023-08-09 | 2023-11-03 | 恒生电子股份有限公司 | 元学习方法、文本分类方法、装置、电子设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2020140422A1 (en) | Neural network for automatically tagging input image, computer-implemented method for automatically tagging input image, apparatus for automatically tagging input image, and computer-program product | |
US9990558B2 (en) | Generating image features based on robust feature-learning | |
Hu et al. | A survey on online feature selection with streaming features | |
Wilmanski et al. | Modern approaches in deep learning for SAR ATR | |
Alhassan et al. | Brain tumor classification in magnetic resonance image using hard swish-based RELU activation function-convolutional neural network | |
De Rosa et al. | Handling dropout probability estimation in convolution neural networks using meta-heuristics | |
JP2022542639A (ja) | 生物学関連のデータを処理するための機械学習アルゴリズムをトレーニングするためのシステムおよび方法、顕微鏡ならびにトレーニングされた機械学習アルゴリズム | |
JP2018513491A (ja) | 2部グラフラベルの調査によるファイングレイン画像分類 | |
Dixit et al. | Texture classification using convolutional neural network optimized with whale optimization algorithm | |
Ma et al. | Lightweight attention convolutional neural network through network slimming for robust facial expression recognition | |
Karalas et al. | Deep learning for multi-label land cover classification | |
CN113256592B (zh) | 图像特征提取模型的训练方法、系统及装置 | |
US11580384B2 (en) | System and method for using a deep learning network over time | |
Aamir et al. | A deep contractive autoencoder for solving multiclass classification problems | |
CN116109901A (zh) | 自适应正则化扭曲梯度下降的小样本元学习方法、系统、终端及介质 | |
De Silva et al. | Wavelet based edge feature enhancement for convolutional neural networks | |
Rasheed et al. | Brain tumor classification from MRI using image enhancement and convolutional neural network techniques | |
Kahraman et al. | Classification of defective fabrics using capsule networks | |
Roy et al. | L3DMC: Lifelong Learning using Distillation via Mixed-Curvature Space | |
Georgakopoulos et al. | A novel adaptive learning rate algorithm for convolutional neural network training | |
Reddy et al. | Classification of health care products using hybrid CNN-LSTM model | |
Escorcia-Gutierrez et al. | Intelligent sine cosine optimization with deep transfer learning based crops type classification using hyperspectral images | |
Pias et al. | Perfect storm: DSAs embrace deep learning for GPU-based computer vision | |
Glory Precious et al. | Deployment of a mobile application using a novel deep neural network and advanced pre-trained models for the identification of brain tumours | |
Mormille et al. | Introducing inductive bias on vision transformers through gram matrix similarity based regularization |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |