CN114048843A - 一种基于选择性特征迁移的小样本学习网络 - Google Patents
一种基于选择性特征迁移的小样本学习网络 Download PDFInfo
- Publication number
- CN114048843A CN114048843A CN202111400311.3A CN202111400311A CN114048843A CN 114048843 A CN114048843 A CN 114048843A CN 202111400311 A CN202111400311 A CN 202111400311A CN 114048843 A CN114048843 A CN 114048843A
- Authority
- CN
- China
- Prior art keywords
- data
- training
- meta
- network
- module
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明公开一种基于选择性特征迁移的小样本学习网络,包括元学习模块、ResNet模块、选择性对抗迁移网络模块和自注意力模块;元学习模块包括元训练阶段和元测试阶段,元训练阶段通过在与目标任务相近的任务上学习,训练得到能够作为目标任务初始化起点的预训练模型;元测试阶段是在预训练模型上训练目标任务;ResNet模块采用了层级之间的恒等映射,通过残差学习的方式进行训练;选择性对抗迁移网络模块由一个生成器和若干个鉴别器组成;自注意力模块能够对同一类别的样本点进行均值的计算,求出类别的原型向量,再通过每个样本与原型向量的欧几里得距离求得每个样本的权重,从而辅助模型的训练。
Description
技术领域
本发明涉及人工智能图像识别中的小样本学习,特别是涉及一种基于选择性特征迁移的小样本学习网络。
背景技术
近年来,由于GPU等强大的计算设备、ImageNet等大规模数据集、CNN等先进的模型和算法,人工智能在很多领域都加快了和人类一样的步伐,打败了人类。AlphaGo击败了人类的围棋冠军,ResNet击败了人类对ImageNet的1000类数据的分类率;而在其他领域,人工智能作为高度智能的工具,如语音助手、搜索引擎、自主驾驶汽车、工业机器人等,都进入到了人类的日常生活。
尽管人工智能的繁荣,但在它像人类一样行动之前,它仍有一些重要的任务要做,其中之一就是要从为数不多的数据中快速归纳出执行任务。回想一下,人类可以将自己所学到的东西迅速归纳到新的任务场景中去,快速地将其归纳为新的任务场景。例如,给定一张照片中的一个陌生人,人类可以从大量的照片中轻松地辨认出来。人类可以将过去所学到的东西结合到新的例子中,因此可以快速地概括成新的任务。相比之下,上述成功的应用依赖于从大规模数据中进行详尽的学习。
对于从有限的受监督信息中学习以掌握任务的渴求,出现了一个新的机器学习问题,称为小样本学习(Few-Shot Learning)。当只有一个范本需要学习时,小样本学习也被称为单样本学习问题。小样本学习可以通过整合先验知识来学习有限的受监督信息的新任务。
小样本学习可以帮助减轻工业用途下的大规模收集有标签数据的负担。例如,ResNet对ImageNet数据的1000类的分类率打败了人类的分类率。然而,这是在每个类都有足够的标签图像的情况下。相比之下,人类可以识别大约30000个类,而在这种情况下,收集足够多的类的图像对于机器来说是非常费力的。这几乎是不可能完成的任务。
小样本学习的另一个经典场景是由于某些原因,如隐私、安全或伦理问题等,难以或无法获得被监督的信息的任务。例如,药物发现是发现新分子的特性,从而确定有用的新药的过程。然而,由于可能的毒性、活性低、溶解度低等原因,这些新分子在临床上的实际生物学记录并不多。这使得药物发现任务成为小样本学习问题。
小样本学习的常见方法有三种:基于度量的小样本学习、基于模型的小样本学习的和基于优化的小样本学习。基于度量的小样本学习的核心思想类似于近邻算法(即k-NNN分类器和k-means聚类)和内核密度估计。产生一个核函数,测量两个数据样本之间的相似度。学习好的核函数对于基于度量的小样本学习模型的成功至关重要。度量学习很好地契合了这一意图,因为它的目标是在对象之上学习一个度量或距离函数。一个好的度量的概念是依赖于问题的。它应该代表任务空间中输入数据之间的关系,并促进问题的解决。基于模型的小样本学习模型不对形式做任何假设。相反,它依赖于一个专门为快速学习而设计的模型:一个只需几个训练步骤就能快速更新参数的模型。这种快速的参数更新可以由它的内部架构来实现,也可以由一个元学习模型来控制。深度学习模型通过梯度的反向传播来学习。然而,基于梯度的优化算法既不能应对少量的训练样本,也不能在少量的优化步骤内完成优化。因此亟需一种方法可以调整优化算法,让模型在少量的例子中就能很好的学习。
发明内容
传统的小样本网络中,特征的迁移往往伴随着负迁移的影响,严重损害了模型的性能,本发明的目的是为了克服现有技术中的不足,提供一种基于选择性特征迁移的小样本学习网络,为了训练可以在新的小样本目标任务下,可以快速收敛的小样本学习网络。
本发明的目的是通过以下技术方案实现的:
一种基于选择性特征迁移的小样本学习网络,包括元学习模块、ResNet模块、选择性对抗迁移网络模块和自注意力模块;
元学习模块包括元训练阶段和元测试阶段,元训练阶段通过在与目标任务相近的任务上学习,训练得到能够作为目标任务初始化起点的预训练模型;元测试阶段是在预训练模型上训练目标任务;
ResNet模块采用了层级之间的恒等映射,通过残差学习的方式进行训练;
选择性对抗迁移网络模块由一个生成器网络和若干个判别器网络组成;生成器网络生成的数据通过一个classifier层得到关于每个判别器网络的权重向量,根据权重向量加权,再传递给后续的判别器网络,每个判别器网络判别收到的数据;数据分布相近的类别所对应的判别器网络,其对应的权重会更大,所以特征的迁移由该判别器网络进行,避免了负迁移的现象,实现了提升特征迁移的效果;
自注意力模块能够对同一类别的样本点进行均值的计算,求出类别的原型向量,再通过每个样本与原型向量的欧几里得距离求得每个样本的权重,从而辅助整个小样本学习网络的训练。
进一步的,元训练阶段中收集数据得到数据集,将数据集按类别随机划分为源域和目标域数据;在一轮训练中,首先将源域数据划分为N个类别,每个类别中有K个数据,输入进初始化的小样本学习网络,进行训练;得到的特征向量再和目标域数据一同输入选择性对抗迁移网络模块进行对抗训练,源域的数据经过生成器后通过分类,得到不同判别器的权重,再由判别器对源域和目标域的数据进行区分,计算出的损失通过反向传播的方式,更新整个小样本学习网络;在元测试阶段,将目标任务的数据输入进已经预训练好的预训练模型,得到小样本学习网络的预测分类结果。
与现有技术相比,本发明的技术方案所带来的有益效果是:
本发明提出的基于选择性特征迁移的小样本学习方法可以筛去会造成负迁移效果的样本,提高小样本学习的准确率。小样本学习往往采用元学习的范式进行模型的训练。而元学习往往需要大量的数据进行训练,数据中的许多类别的数据分布和目标任务的相差过多。所以元学习中的特征迁移很多时候往往伴有负迁移,负迁移不仅对小样本学习网络的精度没有任何帮助,还会因为参数的错误学习导致伤害到小样本学习网络的准确性。而本发明采用的基于选择性特征迁移的小样本学习网络可以选择性的进行特征迁移,从而过滤掉会造成负迁移效果的源域样本,有效提升了小样本学习网络的精度。
同时本发明提出的小样本学习方法采用了自注意力的机制,可以有效避免离群点在小样本学习网络预测过程中造成的偏差,提高小样本学习方法的准确率。小样本学习的场景下,训练数据非常的稀少。而现有的小样本学习方法大都通过这稀少的训练数据计算一个类别的原型向量,以用于后续的距离度量的比较。但抽样的小样本数据中,一旦有离群点存在,会造成原型向量严重的偏差,从而影响后续测试阶段,小样本学习网络预测的准确性。本发明在计算原型向量的过程中,采用自注意力的方式,通过同一类别各样本间的距离,修正他们的权重,从而降低离群点的重要性,使计算得到的原型向量更接近真实情况下的类别原型。
附图说明
图1是ResNet模块中的跳跃级联与恒等映射示意图。
图2是整体网络的架构示意图。原始数据首先输入到特征提取网络(ResNet模块)中得到对应的特征向量。然后将特征向量输入后续的模块进行深度学习的训练。自注意力模块将特征向量转化为带权的向量后进行分类任务的学习。选择性迁移网络为不同的源域数据赋予不同的权重,从而减轻负迁移的影响。
具体实施方式
以下结合附图和具体实施例对本发明作进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
本发明提供一种基于选择性特征迁移的小样本学习网络,见图2,包括元学习模块、ResNet模块、选择性对抗迁移网络模块和自注意力模块;
1.元学习模块
在小样本的深度学习场景中,目标任务的样本数量非常稀少,直接在目标任务上构建模型会出现非常严重的过拟合现象,模型的效果无法达到可以实际应用的标准。所以往往采用元学习的范式进行模型的构建。元学习从人类可以快速学习一个新事物的能力中得到启发,收集大量与目标任务相近的任务,在这些任务上进行训练。通过在与目标任务相近的大量任务上学习,训练得到一个泛化能力强,可以作为目标任务初始化起点的预训练模型,这一阶段称作元训练阶段。然后在模型上训练目标任务,这一阶段因为之前模型已经学习到了大量的泛化的特征信息,所以可以在少量迭代后快速拟合,这一阶段称作元测试阶段。
2.ResNet模块
在小样本的深度学习场景中,很容易因为样本量不足的原因出现过拟合的情况,所以无法选用层数非常深,参数量非常巨大的深度模型来进行训练。而从经验来看,深度学习中深度神经网络的深度对模型的性能至关重要,当增加深度神经网络的层数后,深度神经网络可以进行更加复杂的特征模式的提取,所以当深度神经网络的层数更深时理论上可以取得更好的结果。ResNet采用了层级之间的恒等映射(Identity mapping),通过残差学习的方式进行训练,见图1。残差学习相比原始特征的直接学习更加容易。当残差为0时,此时堆积层仅仅做了恒等映射,至少网络性能不会下降,实际上残差不会为0,这也会使得堆积层在输入特征基础上学习到新的特征,从而拥有更好的性能。所以引入ResNet的残差模块,本发明可以构建比传统的深度网络模型更深的神经网络。
3.选择性对抗迁移网络模块(Selective Adversarial Network)
传统的Adversarial Network模型都是一个Generator模块搭配一个Discriminator模块。在两个模块上进行对抗训练,通过Discriminator模块对源域和目标域的数据进行分类判断,而Generator模块学习生成可以混淆Discriminator的数据分布,使源域和目标域的数据分布近似相同,从而完成对数据的迁移。但这样的架构存在一个问题,如果来自源域的数据和来自目标域的数据分布相差过大,数据迁移不会产生特别好的效果,甚至还可能会出现负迁移的现象,伤害模型的性能。而在小样本学习的任务中,往往会收集大量数据,增强元模型的泛化能力,而这些数据中有很多类比的数据分布和目标任务的数据分布相差较多,所以直接在元模型上通过收集的源域数据进行特征迁移可能会造成较严重的负迁移。
所以本实施例采用选择性对抗迁移网络模块(Selective Adversarial Network)进行特征的迁移。选择性对抗迁移网络由一个Generator模块(生成器)和多个Discriminator模块(鉴别器)组成,见图2。Generator模块生成的数据通过一个classifier层得到关于每个Discriminator模块的权重向量,根据权重向量加权,再传递给后续的Discriminator模块,每个Discriminator模块判别特定域的数据。数据分布相近的类别所在的模块,权重会大,所以特征的迁移主要由该模块进行,从而避免了负迁移的现象,达到了提升特征迁移效果的作用。
4.自注意力模块(self-attention)
人类在先验知识的训练中形成了一套机制迫使他们将注意力转移到视线范围内的特定区域,这就是注意力机制(Attention Mechanism)。注意机制是近年来比较有用的工具,最早用于自然语言处理领域,后来在计算机视觉领域也得到了广泛地应用。它可以通过在整个图像中学习一个权重矩阵,来聚焦重要区域,抑制非必要信息。
在小样本学习中,因为数据的缺少,所以数据的抽样分布和该类别的真实分布往往会有一定的偏差,所以可以通过注意力的方式对样本赋予权重,调整不同的样本在模型训练中的重要性。首先可以对同一类别的样本点进行均值的计算,求出类别的原型向量,再通过每个样本与原型向量的欧几里得距离求得每个样本的权重,从而辅助小样本学习网络的训练。
本小样本学习网络基于元学习的应用范式,基于元训练和元测试两个阶段。首先在元训练阶段,收集大量数据(如ImageNet数据集),将数据集按类别随机划分为源域和目标域数据。首先将源域数据划分为N-way-K-shot(一轮训练N个类别,每个类别K个数据)输入进初始化的模型,进行训练。得到的特征向量再和目标域数据一同输入选择性对抗迁移网络模块进行对抗训练,源域的数据经过Generator模块后通过分类,得到不同Discriminator模块的权重,再由Discriminator模块对源域和目标域的数据进行区分,计算出的损失通过反向传播的方式,更新整个小样本学习网络。在元测试阶段,将目标任务的数据输入进已经预训练好的小样本学习网络,得到小样本学习网络的预测分类结果。
本发明并不限于上文描述的实施方式。以上对具体实施方式的描述旨在描述和说明本发明的技术方案,上述的具体实施方式仅仅是示意性的,并不是限制性的。在不脱离本发明宗旨和权利要求所保护的范围情况下,本领域的普通技术人员在本发明的启示下还可做出很多形式的具体变换,这些均属于本发明的保护范围之内。
Claims (2)
1.一种基于选择性特征迁移的小样本学习网络,其特征在于,包括元学习模块、ResNet模块、选择性对抗迁移网络模块和自注意力模块;
元学习模块包括元训练阶段和元测试阶段,元训练阶段通过在与目标任务相近的任务上学习,训练得到能够作为目标任务初始化起点的预训练模型;元测试阶段是在预训练模型上训练目标任务;
ResNet模块采用了层级之间的恒等映射,通过残差学习的方式进行训练;
选择性对抗迁移网络模块由一个生成器网络和若干个判别器网络组成;生成器网络生成的数据通过一个classifier层得到关于每个判别器网络的权重向量,根据权重向量加权,再传递给后续的判别器网络,每个判别器网络判别收到的数据;数据分布相近的类别所对应的判别器网络的权重会更大,所以特征的迁移由对应的判别器网络进行,避免了负迁移的现象,实现了提升特征迁移的效果;
自注意力模块能够对同一类别的样本点进行均值的计算,求出类别的原型向量,再通过每个样本与原型向量的欧几里得距离求得每个样本的权重,从而辅助整个小样本学习网络的训练。
2.根据权利要求1所述一种基于选择性特征迁移的小样本学习网络,其特征在于,元训练阶段中收集数据得到数据集,将数据集按类别随机划分为源域和目标域数据;在一轮训练中,首先将源域数据划分为N个类别,每个类别中有K个数据,输入进初始化的小样本学习网络,进行训练;得到的特征向量再和目标域数据一同输入选择性对抗迁移网络模块进行对抗训练,源域的数据经过生成器后通过分类,得到不同判别器的权重,再由判别器对源域和目标域的数据进行区分,计算出的损失通过反向传播的方式,更新整个小样本学习网络;在元测试阶段,将目标任务的数据输入进已经预训练好的预训练模型,得到小样本学习网络的预测分类结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111400311.3A CN114048843A (zh) | 2021-11-19 | 2021-11-19 | 一种基于选择性特征迁移的小样本学习网络 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111400311.3A CN114048843A (zh) | 2021-11-19 | 2021-11-19 | 一种基于选择性特征迁移的小样本学习网络 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114048843A true CN114048843A (zh) | 2022-02-15 |
Family
ID=80211609
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111400311.3A Pending CN114048843A (zh) | 2021-11-19 | 2021-11-19 | 一种基于选择性特征迁移的小样本学习网络 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114048843A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115037641A (zh) * | 2022-06-01 | 2022-09-09 | 网络通信与安全紫金山实验室 | 基于小样本的网络流量检测方法、装置、电子设备及介质 |
CN117541555A (zh) * | 2023-11-16 | 2024-02-09 | 广州市公路实业发展有限公司 | 一种道路路面病害检测方法及系统 |
-
2021
- 2021-11-19 CN CN202111400311.3A patent/CN114048843A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115037641A (zh) * | 2022-06-01 | 2022-09-09 | 网络通信与安全紫金山实验室 | 基于小样本的网络流量检测方法、装置、电子设备及介质 |
CN115037641B (zh) * | 2022-06-01 | 2024-05-03 | 网络通信与安全紫金山实验室 | 基于小样本的网络流量检测方法、装置、电子设备及介质 |
CN117541555A (zh) * | 2023-11-16 | 2024-02-09 | 广州市公路实业发展有限公司 | 一种道路路面病害检测方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Patro et al. | U-cam: Visual explanation using uncertainty based class activation maps | |
CN109993100B (zh) | 基于深层特征聚类的人脸表情识别的实现方法 | |
Chen et al. | Learning linear regression via single-convolutional layer for visual object tracking | |
Wang et al. | Describe and attend to track: Learning natural language guided structural representation and visual attention for object tracking | |
Ajagbe et al. | Investigating the efficiency of deep learning models in bioinspired object detection | |
CN114048843A (zh) | 一种基于选择性特征迁移的小样本学习网络 | |
Wang et al. | Adversarial learning for zero-shot domain adaptation | |
Zhang et al. | Classification of canker on small datasets using improved deep convolutional generative adversarial networks | |
Cui et al. | WEDL-NIDS: Improving network intrusion detection using word embedding-based deep learning method | |
Slade et al. | An evolving ensemble model of multi-stream convolutional neural networks for human action recognition in still images | |
Shehu et al. | Lateralized approach for robustness against attacks in emotion categorization from images | |
Zhuang et al. | A handwritten Chinese character recognition based on convolutional neural network and median filtering | |
Aygun et al. | Exploiting convolution filter patterns for transfer learning | |
US20230419170A1 (en) | System and method for efficient machine learning | |
Reese et al. | LB-CNN: Convolutional neural network with latent binarization for large scale multi-class classification | |
Raj et al. | Object detection and recognition using small labeled datasets | |
Liang et al. | Facial feature extraction method based on shallow and deep fusion CNN | |
Kanungo | Analysis of Image Classification Deep Learning Algorithm | |
Gulshad et al. | Hierarchical explanations for video action recognition | |
Jayaram et al. | A brief study on rice diseases recognition and image classification: fusion deep belief network and S-particle swarm optimization algorithm | |
Shi et al. | Tracking topology structure adaptively with deep neural networks | |
Eghbali et al. | Deep Convolutional Neural Network (CNN) for Large-Scale Images Classification | |
Ukil et al. | Adv-resnet: Residual network with controlled adversarial regularization for effective classification of practical time series under training data scarcity problem | |
Liu et al. | Multimedia classification using bipolar relation graphs | |
Lanzetta | Machine learning, deep learning, and artificial intelligence |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |