CN114492581A - 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 - Google Patents
基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 Download PDFInfo
- Publication number
- CN114492581A CN114492581A CN202111615640.XA CN202111615640A CN114492581A CN 114492581 A CN114492581 A CN 114492581A CN 202111615640 A CN202111615640 A CN 202111615640A CN 114492581 A CN114492581 A CN 114492581A
- Authority
- CN
- China
- Prior art keywords
- learning
- training
- network
- meta
- pictures
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于迁移学习和注意力机制元学习应用在小样本图片分类的方法,该方法从大规模的训练数据中学习到先验知识,在只使用少量标记训练数据情况下,可以帮助深度神经网络更快的收敛,同时降低网络过拟合的可能性。该方法采用DenseNet网络作为特征提取器,小样本分类任务的难点就是样本量少,本方法采用的特征提取器网络采用特征重用的方法,将有限的图片进行充分的利用。大规模数据训练为深度网络权值提供了良好的初始化,使元学习在较少的任务下能够快速的收敛,这些操作保持了训练后的深度网络权重不变,从而避免了灾难遗忘的问题。
Description
技术领域
本发明属于深度学习图片分类领域,特别是涉及了一种基于迁移学习和注意力机制元学习应用在小样本图片分类的方法。
背景技术
深度学习在很多领域都取得了很大的成就,比如在目标检测、图像分类、语义分割等方面得出的结果都可以超过人类,但它们通常需要很多的数据才能达到比较高的准确度,并且收集和注释大量的数据成本也是非常昂贵的。而人类可以通过一本书中的一幅插图概括出“狮子”的概念,那么让机器从少量样本中“概括”出某种物体的概念吸引了大量研究者的关注。从少量数据中学习是机器视觉面临的挑战,近年来,元学习在改善机器视觉的少样本学习中表现出了良好的性能。
元学习即“学会学习”,与传统的机器学习方法不同,传统的机器学习的方法是使用固定的学习算法从头开始解决给定的任务,元学习的目的是改进学习算法本身,在多个学习任务中获得经验,通常覆盖相关任务的分布,并使用这种经验来改进未来的学习性能。元学习是一种任务级学习方法,旨在通过学习多个任务积累经验,而基础学习器则侧重于对单个任务的数据分布进行建模。这方面的代表就是模型不确定的元学习(MAML),学习搜索最优初始化状态,以快速适应基础学习器的新任务。它的任务不可知特性使其有可能推广到少样本监督学习和无监督强化学习。然而这种方法存在着局限性,每个任务通常由一个低复杂度的基础学习器(如浅层神经网络)建模,以避免模型过拟合,从而无法使用更深入和更强大的网络架构。并且现有的元学习方法通常忽视了注意力机制的存在,注意力机制在认知和学习的过程中被证明是重要的。
随着深度网络研究的不断深入,一个新的问题出现了:随着关于输入或梯度的信息经过许多层,当它到达网络的终点或起点时,它可能消失或“洗掉”。近些年提出许多了解决这个问题的方法,例如:ResNets通过跨越连接从一层连接到下一层,但实际中许多层的贡献很小,可以在训练中随机丢弃。这使得ResNets的状态类似于展开的循环神经网络,但是它的参数数量却很大,因为每一层都有自己的权重。本文中使用DenseNet来作为特征提取器,DenseNet明确地区分了添加到网络中的信息和保留的信息。DenseNet层非常狭窄(例如,每层12个滤波器),只向网络的“集体知识”中添加一小组特征图,并保持剩余的特征图不变——最终的分类器基于网络中的所有特征图做出决策。DenseNet的另一大优势是其改进的信息流动和整个网络的梯度,这使它们易于训练。每一层都有直接访问从损失函数和原始输入信号的梯度,导致一个隐式的深度监督。
近年来,注意机制也被广泛应用在了计算机视觉系统、机器翻译中。神经网络中的注意力机制是在计算能力有限的情况下,将计算资源分配给更重要的任务,同时解决信息超载问题的一种资源分配方案。在神经网络学习中,一般而言模型的参数越多则模型的表达能力越强,模型所存储的信息量也越大,但这会带来信息过载的问题。那么通过引入注意力机制,在众多的输入信息中聚焦于对当前任务更为关键的信息,降低对其他信息的关注度,甚至过滤掉无关信息,就可以解决信息过载问题,并提高任务处理的效率和准确性。那么什么是注意力机制呢?当我们在看一个场景的时候,我们看到的一定是某个场景的某一个地方,当我们的视觉在移动时,注意力随着目光的移动也在移动。也就是说,当人在注意到某个场景时,该场景内每一空间上的注意力分布是不一致的。因此,可以借鉴人脑的注意力机制,只选择一些关键的信息输入进行处理,来提高神经网络的效率。
发明内容
对于以上所提出的问题,针对训练样本过少的问题,本发明使用DenseNet网络来进行预训练,来提取特征,预训练结束后冻结DenseNet网络参数,靠大量数据训练出来的网络参数,能保证其具有良好的初始化,并且DenseNet会采用多次特征重用,提高较少特征的利用率。使用注意力机制来对特征提取器提取出来的特征进行通道加权,校正后的特征可保留有价值的特征,剔除没价值的特征。在元学习阶段用位移和偏差来进行参数的元学习,即减少了网络的参数,又避免了灾难性遗忘的问题。
一种基于迁移学习和注意力机制元学习的应用在小样本图片分类的方法,包括以下步骤:
(1)获取数据:读取数据集中预训练中的图片,其中图片按任务划分,不同任务的图片在不同的文件夹中,按照任务分布进行图片的读取;
(2)迁移学习和注意力机制元学习网络框架的搭建:包括固定的特征提取器和由于预训练与元学习阶段分类任务数量的不同,采用的不同的类别输出层;
(2.1)预训练阶段的模型框架包括:用DenseNet网络作为特征提取器,来提取输入图片的特征,后面接一个平均池化层,来对将经过DenseNet网络提取出来的特征进行降维,去除冗余信息,对池化后的特征进行展平,后面接一个全连接层,最后是一个根据分类任务确定的类别输出层;
(2.2)元学习阶段的模型框架包括:用DenseNet网络作为特征提取器,来提取输入图片的特征,将提取出来的图片特征输入到注意力机制模块中,注意力机制使用的是通道注意力,将每个通道的特征图进行全局平均池化,得到注意力加权值,再将这个加权值应用于原来的特征图中,对每个通道的数值进行加权。加权后的特征进行展平,后接一个全连接层,最后是一个与训练阶段不同的类别输出层;
(3)对步骤(2)搭建的预训练阶段的模型框架和元学习阶段的模型框架进行训练;
(3.1)初始化预训练网络参数,把训练数据集输入到预训练的网络框架中去,来对预训练的网络框架参数进行优化,学习卷积层中的网络参数权值W和偏差b,通过交叉熵损失函数来减小同一任务之间的特征分布,最后用softmax计算样本所属概率最大的类别即为图片预训练阶段的预测类别;
(3.2)预训练网络的参数更新;
(3.3)重复步骤(3.1)、(3.2)直至网络迭代次数达到预设迭代次数,取最好精度的迭代次数的网络参数Θ(W,b);
(3.4)预训练结束,网络参数Θ(W,b)固定不再更新;
(3.5)初始化元学习网络参数,将测试数据输入到元学习的网络框架中去,此时网络中特征提取器卷积层所使用的权值W和偏差b是预训练阶段迭代精度最好的参数,并引入两个新的参数:缩放和平移;
(3.6)元学习网络参数更新;
(3.7)重复步骤(3.5)、(3.6)直至网络迭代次数达到预设迭代次数,取最好精度的迭代次数的网络参数;
(3.8)元学习阶段结束后,用验证数据集对网络模型进行验证,网络最终输出的分类精度为最后的模型评估精度。
进一步的,所述步骤(1)包括以下步骤:
(1.1)为了提高分类的难度,数据集中的图片尺寸为84×84,将图片尺寸resize到40×40,然后随机切割36×36尺寸大小的图片;每张图片转换为RGB三通道,转换为c×h×w的三维矩阵,h和w分别为图像的高和宽,c为通道数;
(1.2)训练图片转换成nS×c×h×w的四维矩阵数据,nS表示任务T训练样本数;同一任务中随机抽取没有训练的图片作为验证数据转换为nT×c×h×w的四维矩阵数据,nT表示同一任务中用来验证的数据样本数;
(1.3)图片类别用one-hot编码,图片共有N类,则第一类的标签表示[1,0,0,...,0]1×N,第二类的标签表示为[0,1,0,...,0]1×N,…,第N类图片的标签表示为[0,0,0,...,1]1×N。
进一步的,所述步骤(2)包括以下步骤:
特征提取器的训练:用训练数据集对DenseNet网络进行预训练,获得DenseNet密集连接网络模型参数,其中密集连接网络输入为batch_size×c×h×w的矩阵,batch_size的大小取决于计算机内存;
所述预训练阶段:密集连接网络用来提取特征,提取出来特征有342通道,后面进行展平,为了降低展平后的维度,对每个通道上的特征进行全局池化,得到一个342×1×1的向量,展平后接一层全连接层,固定全连接层的神经元个数为600,最后一层是输出层,用softmax函数计算图片所属每种类别的概率;
所述元学习阶段:密集连接网络用来特征提取,提取出来的特征有342通道,把提取出来的特征送入到注意力机制模块中,特征图的形状为(128,10,10,342),其中128,10,10,342分别是特征图的batch-size、宽度、高度和通道数量;后面进行展平,为了降低展平后的维度,对每个通道上的特征进行全局池化,得到一个342×1×1的向量,展平后接一层全连接层,固定全连接层的神经元个数为600,最后一层是输出层,用softmax函数计算图片所属每种类别的概率。
进一步的,所述步骤(3.2)包括以下步骤:
初始化网络参数,预训练部分:将训练数据集图片输入到DenseNet密集连接网络中去,在DenseNet密集连接网络中,会有特征重用的阶段,因此,网络框架第l层收所有之前层的特征图,x0…xl-1,输入:
xl=Hl([x0,x1,...,xl-1]) (1)
其中[x0,x1,...,xl-1]表示在第0层至l-1层中生成的所有特征图的级联,Hl表示将级联后形成的张量;
为了便于实现,多个输入连接成一个张量;在这个阶段,先不考虑来自其他数据集的数据或领域自适应,并在现成的少样本学习基准数据上进行预训练;具体来说,对于一个特定的少样本数据集,合并所有类数据D进行预训练。
首先随机初始化一个特征提取器Θ和辅助分类器θ,然后通过梯度下降对其进行优化,
其中L定义为经验损失,α为学习率,设置学习率α为0.01,
在此阶段,学习特征提取器Θ;它学习到的参数将在接下来的元训练和元测试阶段被冻结;学习到的辅助分类器θ将被丢弃,因为后续的少样本任务包含不同的分类目标。
进一步的,所述步骤(3.5)中引入的两个新的参数:缩放和平移,记为和对特征提取器中的权重W进行加权缩放,对偏差b进行加权平移;将特征提取器提取出来的图片特征输入到注意力机制层中,对提取出来的特征进行通道加权,然后将加权后的特征展平、全连接,最后用softmax计算样本所属概率最大的类别即为图片元学习阶段的预测类别。
进一步的,所述步骤(3.6)包括以下步骤:
对于给定的任务T,利用任务T中训练数据的损失通过梯度下降优化当前的基础学习器,也即是分类器θ′:
其中⊙表示元素对应相乘。
与现有技术相比,本发明有益效果是:
本发明用元学习的方法来处理小样本分类的问题,小样本分类本身样本量较少,可以利用的特征也会相对来说会比较少,用DenseNet密集连接网络作为特征提取器,网络框架中的特征重用会充分利用每一步卷积之后提取出来的特征图,将有限的特征进行充分的利用。同时在预训练阶段,通过大规模数据图片来训练网络参数,会得到较好的网络参数,得到较好的先验知识。在元学习阶段引入缩放和平移的参数,通过训练集对这两个参数进行更新,缩小了网络训练的参数量,同时充分利用了先验知识,不会造成遗忘灾难的问题。卷积提取特征是非常重要的,注意力机制的引入可对特征进行校正,校正后的特征可保留有价值的特征,剔除没价值的特征。提高模型的分类精度,提升小样本分类特征的利用率,以便更准确地在新任务的情况下模型对图片所属类别的预测精度。
附图说明
图1是本发明所述迁移学习和注意力机制元学习网络结构图。
图2是本发明所述在miniImageNet数据集的交叉验证实验精度上升曲线。
具体实施方式
下面结合附图及具体实验,对本发明做进一步说明。
小样本分类样本量少是普遍公认的问题,元学习其具有的特性可以很契合的解决这个问题。具体来说就是用于训练的样本量较少,使用常规的模型会陷入过拟合的问题。由于本发明使用的是大型网络框架,使用全部参数微调的方法显然是不现实的。根本动机就是:用元学习的方法减少模型框架更新参数的数量来进行小样本分类。
本发明的一种基于基于迁移学习和注意力机制元学习来解决小样本图片分类方法包括以下步骤:
一种基于迁移学习和注意力机制元学习的应用在小样本图片分类的方法,包括以下步骤:
(1)获取数据:读取数据集中预训练中的图片,其中图片按任务划分,不同任务的图片在不同的文件夹中,按照任务分布进行图片的读取;
(1.1)为了提高分类的难度,数据集中的图片尺寸为84×84,将图片尺寸resize到40×40,然后随机切割36×36尺寸大小的图片;每张图片转换为RGB三通道,转换为c×h×w的三维矩阵,h和w分别为图像的高和宽,c为通道数;
(1.2)训练图片转换成nS×c×h×w的四维矩阵数据,nS表示任务T训练样本数;同一任务中随机抽取没有训练的图片作为验证数据转换为nT×c×h×w的四维矩阵数据,nT表示同一任务中用来验证的数据样本数;
(1.3)图片类别用one-hot编码,图片共有N类,则第一类的标签表示[1,0,0,...,0]1×N,第二类的标签表示为[0,1,0,...,0]1×N,…,第N类图片的标签表示为[0,0,0,...,1]1×N。
(2)迁移学习和注意力机制元学习网络框架的搭建:包括固定的特征提取器和由于预训练与元学习阶段分类任务数量的不同,采用的不同的类别输出层;
特征提取器的训练:用训练数据集对DenseNet网络进行预训练,获得DenseNet密集连接网络模型参数,其中密集连接网络输入为batch_size×c×h×w的矩阵,batch_size的大小取决于计算机内存;
所述预训练阶段:密集连接网络用来提取特征,提取出来特征有342通道,后面进行展平,为了降低展平后的维度,对每个通道上的特征进行全局池化,得到一个342×1×1的向量,展平后接一层全连接层,固定全连接层的神经元个数为600,最后一层是输出层,用softmax函数计算图片所属每种类别的概率;
所述元学习阶段:密集连接网络用来特征提取,提取出来的特征有342通道,把提取出来的特征送入到注意力机制模块中,特征图的形状为(128,10,10,342),其中128,10,10,342分别是特征图的batch-size、宽度、高度和通道数量;对注意力机制的设置可以理解为:使用一些网络去计算一个权重,把这个权重与特征图进行运算,对这个特征图进行改变,得到加强注意力后的特征图。卷积提取特征是非常重要的,注意力机制可对特征进行校正,校正后的特征可保留有价值的特征,剔除没价值的特征。后面进行展平,为了降低展平后的维度,对每个通道上的特征进行全局池化,得到一个342×1×1的向量,展平后接一层全连接层,固定全连接层的神经元个数为600,最后一层是输出层,用softmax函数计算图片所属每种类别的概率。
(3)对步骤(2)搭建的预训练阶段的模型框架和元学习阶段的模型框架进行训练;
(3.1)初始化预训练网络参数,把训练数据集输入到预训练的网络框架中去,来对预训练的网络框架参数进行优化,学习卷积层中的网络参数权值W和偏差b,通过交叉熵损失函数来减小同一任务之间的特征分布,最后用softmax计算样本所属概率最大的类别即为图片预训练阶段的预测类别;
(3.2)预训练网络的参数更新;
初始化网络参数,预训练部分:将训练数据集图片输入到DenseNet密集连接网络中去,在DenseNet密集连接网络中,会有特征重用的阶段,因此,网络框架第l层收所有之前层的特征图,x0…xl-1,输入:
xl=Hl([x0,x1,...,xl-1]) (1)
其中[x0,x1,...,xl-1]表示在第0层至l-1层中生成的所有特征图的级联,Hl表示将级联后形成的张量;
为了便于实现,多个输入连接成一个张量;在这个阶段,先不考虑来自其他数据集的数据或领域自适应,并在现成的少样本学习基准数据上进行预训练;具体来说,对于一个特定的少样本数据集,合并所有类数据D进行预训练。例如,对于miniImageNet,在数据集D的训练分割中总共有64个类,每个类包含600个样本,用于预先训练64类分类器。
首先随机初始化一个特征提取器Θ和辅助分类器θ,然后通过梯度下降对其进行优化,
其中L定义为经验损失,α为学习率,设置学习率α为0.01,
在此阶段,学习特征提取器Θ;它学习到的参数将在接下来的元训练和元测试阶段被冻结;学习到的辅助分类器θ将被丢弃,因为后续的少样本任务包含不同的分类目标。
(3.3)重复步骤(3.1)、(3.2)直至网络迭代次数达到预设迭代次数,取最好精度的迭代次数的网络参数Θ(W,b);
(3.4)预训练结束,网络参数Θ(W,b)固定不再更新;
(3.5)初始化元学习网络参数,将测试数据输入到元学习的网络框架中去,此时网络中特征提取器卷积层所使用的权值W和偏差b是预训练阶段迭代精度最好的参数,并引入两个新的参数:缩放和平移;
所述两个新的参数:缩放和平移,记为和对特征提取器中的权重W进行加权缩放,对偏差b进行加权平移;将特征提取器提取出来的图片特征输入到注意力机制层中,对提取出来的特征进行通道加权,然后将加权后的特征展平、全连接,最后用softmax计算样本所属概率最大的类别即为图片元学习阶段的预测类别。
(3.6)元学习网络参数更新;
所述步骤(3.6)包括以下步骤:
对于给定的任务T,利用任务T中训练数据的损失通过梯度下降优化当前的基础学习器,也即是分类器θ′:
这与公式(2)不同,这里没有更新特征提取器Θ,需要注意的是,这里的分类器与前一阶段,即公式(2)中的大规模的辅助分类器θ不同;这个分类器少于大规模分类器,在一个新的少样本的场景中对样本图片进行分类;对应于只在当前任务中工作的分类器,为前一个任务优化的初始化;
其中⊙表示元素对应相乘。
(3.7)重复步骤(3.5)、(3.6)直至网络迭代次数达到预设迭代次数,取最好精度的迭代次数的网络参数;
(3.8)元学习阶段结束后,用验证数据集对网络模型进行验证,网络最终输出的分类精度为最后的模型评估精度。
本发明可通过以下实验进一步说明:
为了验证本发明的有效性,分别在Omniglo,miniImageNet,FC100数据集上做了实验。
为了体现元学习的多任务性,将数据集分为训练集,验证集和测试集。
由于Omniglot是一个比MiniImagenet简单得多的数据集,现有的元学习方法可以很容易地在Omniglot上生成的大多数测试任务上达到95%以上的准确率,所以我们只在Omniglot上测试TML方法。与在Miniimagenet上的实验相同,我们也在20万个随机生成的任务上训练元学习者并设置学习率为0.001。实验结果如表1所示。可以看出,所提出的方法TML在少镜头图像分类任务中,达到了比较先进的性能。
miniimagenet由Vinyalset提出,用于少样本学习评估。由于使用了ImageNet图像,它的复杂性很高,但与在完整的ImageNet数据集上运行相比,它需要更少的资源和基础设施。总共有100个类别,每个类有600张84×84彩色图片样本。将这100个类分为64个、16个和20个类,分别进行元训练、元验证和元测试的抽样任务,并进行相关工作。
Fewshot-CIFAR100(FC100)是基于目前流行的对象分类数据集CIFAR100。它提供了一个更具挑战性的场景,具有较低的图像分辨率和更具有挑战性的元训练/测试分割(根据对象超类进行分离)。它包含100个对象类,每个类有600个32×32的样本图像。这100个类属于20个超类。元训练数据来自于属于12个超类的60个类。元验证和元测试集包含20个类,分别属于4个超类。这些划分符合超类,从而最小化训练、验证和测试任务之间的信息重叠。
所有训练数据点训练一个大规模的深度神经网络模型,并在100次迭代后停止训练。我们使用与相关工作相同的任务抽样方法。具体来说,1)考虑5类分类;2)对5类1样本或5样本的任务进行抽样,以包含1个或5个样本用于训练,15个(统一)样本用于测试。总共抽取8k个任务进行元训练,分别抽取了600个随机任务进行元验证和元测试。
表1Omniglot数据集实验精度
表2:FC100数据集实验精度
表3:miniImageNet数据集实验精度
表4:交叉实验结果
本发明的方法从大规模的训练数据中学习到先验知识,在只使用少量标记训练数据情况下,可以帮助深度神经网络更快的收敛,同时降低网络过拟合的可能性。该方法采用DenseNet网络作为特征提取器,小样本分类任务的难点就是样本量少,本方法采用的特征提取器网络采用特征重用的方法,将有限的图片进行充分的利用。在预训练阶段采用密集网络对大规模的数据进行训练,来训练网络的权重和偏差,对特征提取器最后提取出来的特征进行展平,后接全连接层和分类层。此时进行的是64分类,由于预训练阶段的数据量比较多,所以训练出来的网络参数也会比较好。在预训练结束后,将训练好的权重和偏差进行固定,修改后面的分类器,以便进行下一步的元学习阶段。在元学习阶段,利用预训练阶段学习到的先验知识,对网络中的权重进行缩放和偏差进行平移来进行元学习,只对这两个参数进行更新,不对权重和偏差进行更新。大规模数据训练为深度网络权值提供了良好的初始化,使元学习在较少的任务下能够快速的收敛,这些操作把持了训练后的深度网络权重不变,从而避免了灾难遗忘的问题,从而提高图像数据集的分类准确度。
应该理解,本发明并不局限于上述具体实例,在本发明开的基础上,凡是熟悉本领域的技术人员在不违背本发明精神的前提下还可做出等同变形或替换,这些等同的变型或替换均包含在本申请权利要求所限定的范围。
Claims (7)
1.一种基于迁移学习和注意力机制元学习的应用在小样本图片分类的方法,其特征在于,包括以下步骤:
(1)获取数据:读取数据集中预训练中的图片,其中图片按任务划分,不同任务的图片在不同的文件夹中,按照任务分布进行图片的读取;
(2)迁移学习和注意力机制元学习网络框架的搭建:包括固定的特征提取器和由于预训练与元学习阶段分类任务数量的不同,采用的不同的类别输出层;
(2.1)预训练阶段的模型框架包括:用DenseNet网络作为特征提取器,来提取输入图片的特征,后面接一个平均池化层,来对将经过DenseNet网络提取出来的特征进行降维,去除冗余信息,对池化后的特征进行展平,后面接一个全连接层,最后是一个根据分类任务确定的类别输出层;
(2.2)元学习阶段的模型框架包括:用DenseNet网络作为特征提取器,来提取输入图片的特征,将提取出来的图片特征输入到注意力机制模块中,注意力机制使用的是通道注意力,将每个通道的特征图进行全局平均池化,得到注意力加权值,再将这个加权值应用于原来的特征图中,对每个通道的数值进行加权。加权后的特征进行展平,后接一个全连接层,最后是一个与训练阶段不同的类别输出层;
(3)对步骤(2)搭建的预训练阶段的模型框架和元学习阶段的模型框架进行训练;
(3.1)初始化预训练网络参数,把训练数据集输入到预训练的网络框架中去,来对预训练的网络框架参数进行优化,学习卷积层中的网络参数权值W和偏差b,通过交叉熵损失函数来减小同一任务之间的特征分布,最后用softmax计算样本所属概率最大的类别即为图片预训练阶段的预测类别;
(3.2)预训练网络的参数更新;
(3.3)重复步骤(3.1)、(3.2)直至网络迭代次数达到预设迭代次数,取最好精度的迭代次数的网络参数Θ(W,b);
(3.4)预训练结束,网络参数Θ(W,b)固定不再更新;
(3.5)初始化元学习网络参数,将测试数据输入到元学习的网络框架中去,此时网络中特征提取器卷积层所使用的权值W和偏差b是预训练阶段迭代精度最好的参数,并引入两个新的参数:缩放和平移;
(3.6)元学习网络参数更新;
(3.7)重复步骤(3.5)、(3.6)直至网络迭代次数达到预设迭代次数,取最好精度的迭代次数的网络参数;
(3.8)元学习阶段结束后,用验证数据集对网络模型进行验证,网络最终输出的分类精度为最后的模型评估精度。
2.根据权利要求1所述一种基于迁移学习和注意力机制元学习应用于少样本图片分类的方法,其特征在于,所述步骤(1)包括以下步骤:
(1.1)为了提高分类的难度,数据集中的图片尺寸为84×84,将图片尺寸resize到40×40,然后随机切割36×36尺寸大小的图片;每张图片转换为RGB三通道,转换为c×h×w的三维矩阵,h和w分别为图像的高和宽,c为通道数;
(1.2)训练图片转换成nS×c×h×w的四维矩阵数据,nS表示任务T训练样本数;同一任务中随机抽取没有训练的图片作为验证数据转换为nT×c×h×w的四维矩阵数据,nT表示同一任务中用来验证的数据样本数;
(1.3)图片类别用one-hot编码,图片共有N类,则第一类的标签表示[1,0,0,...,0]1×N,第二类的标签表示为[0,1,0,...,0]1×N,…,第N类图片的标签表示为[0,0,0,...,1]1×N。
3.根据权利要求1所述一种基于迁移学习和注意力机制元学习应用于少样本图片分类的方法,其特征在于,所述步骤(2)包括以下步骤:
特征提取器的训练:用训练数据集对DenseNet网络进行预训练,获得DenseNet密集连接网络模型参数,其中密集连接网络输入为batch_size×c×h×w的矩阵,batch_size的大小取决于计算机内存;
所述预训练阶段:密集连接网络用来提取特征,提取出来特征有342通道,后面进行展平,为了降低展平后的维度,对每个通道上的特征进行全局池化,得到一个342×1×1的向量,展平后接一层全连接层,固定全连接层的神经元个数为600,最后一层是输出层,用softmax函数计算图片所属每种类别的概率;
所述元学习阶段:密集连接网络用来特征提取,提取出来的特征有342通道,把提取出来的特征送入到注意力机制模块中,特征图的形状为(128,10,10,342),其中128,10,10,342分别是特征图的batch-size、宽度、高度和通道数量;后面进行展平,为了降低展平后的维度,对每个通道上的特征进行全局池化,得到一个342×1×1的向量,展平后接一层全连接层,固定全连接层的神经元个数为600,最后一层是输出层,用softmax函数计算图片所属每种类别的概率。
4.根据权利要求1所述一种基于迁移学习和注意力机制元学习应用于少样本图片分类的方法,其特征在于,所述步骤(3.2)包括以下步骤:
初始化网络参数,预训练部分:将训练数据集图片输入到DenseNet密集连接网络中去,在DenseNet密集连接网络中,会有特征重用的阶段,因此,网络框架第l层收所有之前层的特征图,x0…xl-1,输入:
xl=Hl([x0,x1,...,xl-1]) (1)
其中[x0,x1,...,xl-1]表示在第0层至l-1层中生成的所有特征图的级联,Hl表示将级联后形成的张量;
为了便于实现,多个输入连接成一个张量;在这个阶段,先不考虑来自其他数据集的数据或领域自适应,并在现成的少样本学习基准数据上进行预训练;具体来说,对于一个特定的少样本数据集,合并所有类数据D进行预训练。
首先随机初始化一个特征提取器Θ和辅助分类器θ,然后通过梯度下降对其进行优化,
其中L定义为经验损失,α为学习率,设置学习率α为0.01,
在此阶段,学习特征提取器Θ;它学习到的参数将在接下来的元训练和元测试阶段被冻结;学习到的辅助分类器θ将被丢弃,因为后续的少样本任务包含不同的分类目标。
6.根据权利要求5所述一种基于迁移学习和注意力机制元学习应用于少样本图片分类的方法,其特征在于,所述步骤(3.6)包括以下步骤:
对于给定的任务T,利用任务T中训练数据的损失通过梯度下降优化当前的基础学习器,也即是分类器θ′:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111615640.XA CN114492581A (zh) | 2021-12-27 | 2021-12-27 | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111615640.XA CN114492581A (zh) | 2021-12-27 | 2021-12-27 | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114492581A true CN114492581A (zh) | 2022-05-13 |
Family
ID=81496542
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111615640.XA Pending CN114492581A (zh) | 2021-12-27 | 2021-12-27 | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114492581A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116202611A (zh) * | 2023-05-06 | 2023-06-02 | 中国海洋大学 | 一种基于元学习的小样本声速剖面反演方法 |
WO2024082374A1 (zh) * | 2022-10-19 | 2024-04-25 | 电子科技大学长三角研究院(衢州) | 一种基于层级化元迁移的小样本雷达目标识别方法 |
-
2021
- 2021-12-27 CN CN202111615640.XA patent/CN114492581A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2024082374A1 (zh) * | 2022-10-19 | 2024-04-25 | 电子科技大学长三角研究院(衢州) | 一种基于层级化元迁移的小样本雷达目标识别方法 |
CN116202611A (zh) * | 2023-05-06 | 2023-06-02 | 中国海洋大学 | 一种基于元学习的小样本声速剖面反演方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112308158B (zh) | 一种基于部分特征对齐的多源领域自适应模型及方法 | |
US11256960B2 (en) | Panoptic segmentation | |
Bartz et al. | STN-OCR: A single neural network for text detection and text recognition | |
CN109711426B (zh) | 一种基于gan和迁移学习的病理图片分类装置及方法 | |
Liu et al. | Multi-objective convolutional learning for face labeling | |
Donahue et al. | Decaf: A deep convolutional activation feature for generic visual recognition | |
CN109993100B (zh) | 基于深层特征聚类的人脸表情识别的实现方法 | |
CN111444881A (zh) | 伪造人脸视频检测方法和装置 | |
CN107944410B (zh) | 一种基于卷积神经网络的跨领域面部特征解析方法 | |
KR20190138238A (ko) | 딥 블라인드 전의 학습 | |
CN110188827B (zh) | 一种基于卷积神经网络和递归自动编码器模型的场景识别方法 | |
CN113674334B (zh) | 基于深度自注意力网络和局部特征编码的纹理识别方法 | |
CN111160350A (zh) | 人像分割方法、模型训练方法、装置、介质及电子设备 | |
CN114492581A (zh) | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 | |
CN116503676B (zh) | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 | |
Chen et al. | Automated design of neural network architectures with reinforcement learning for detection of global manipulations | |
CN113673482B (zh) | 基于动态标签分配的细胞抗核抗体荧光识别方法及系统 | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN113159067A (zh) | 一种基于多粒度局部特征软关联聚合的细粒度图像辨识方法及装置 | |
CN109492610B (zh) | 一种行人重识别方法、装置及可读存储介质 | |
CN114693624A (zh) | 一种图像检测方法、装置、设备及可读存储介质 | |
CN109508640A (zh) | 一种人群情感分析方法、装置和存储介质 | |
Xu et al. | Graphical modeling for multi-source domain adaptation | |
Nalini et al. | Comparative analysis of deep network models through transfer learning | |
Bose et al. | Light Weight Structure Texture Feature Analysis for Character Recognition Using Progressive Stochastic Learning Algorithm |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |