CN113378937B - 一种基于自监督增强的小样本图像分类方法及系统 - Google Patents
一种基于自监督增强的小样本图像分类方法及系统 Download PDFInfo
- Publication number
- CN113378937B CN113378937B CN202110657337.XA CN202110657337A CN113378937B CN 113378937 B CN113378937 B CN 113378937B CN 202110657337 A CN202110657337 A CN 202110657337A CN 113378937 B CN113378937 B CN 113378937B
- Authority
- CN
- China
- Prior art keywords
- small sample
- self
- supervision
- sample
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 52
- 238000012549 training Methods 0.000 claims abstract description 16
- 230000004927 fusion Effects 0.000 claims description 23
- 238000013528 artificial neural network Methods 0.000 claims description 20
- 238000005070 sampling Methods 0.000 claims description 16
- 238000010586 diagram Methods 0.000 claims description 10
- 239000011159 matrix material Substances 0.000 claims description 10
- 238000009826 distribution Methods 0.000 claims description 8
- 238000004364 calculation method Methods 0.000 claims description 7
- 230000009466 transformation Effects 0.000 claims description 4
- 238000012512 characterization method Methods 0.000 abstract description 7
- 238000002679 ablation Methods 0.000 abstract description 2
- 238000003860 storage Methods 0.000 description 14
- 238000004590 computer program Methods 0.000 description 10
- 238000012545 processing Methods 0.000 description 6
- 230000006870 function Effects 0.000 description 5
- 238000004088 simulation Methods 0.000 description 5
- 238000013461 design Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 3
- 239000000284 extract Substances 0.000 description 3
- 238000002372 labelling Methods 0.000 description 3
- 238000004519 manufacturing process Methods 0.000 description 3
- 241000282326 Felis catus Species 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000006243 chemical reaction Methods 0.000 description 2
- 230000002708 enhancing effect Effects 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000008447 perception Effects 0.000 description 1
- 230000008569 process Effects 0.000 description 1
- 230000035755 proliferation Effects 0.000 description 1
- 230000001737 promoting effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/253—Fusion techniques of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于自监督增强的小样本图像分类方法及系统,自监督学习和小样本学习都是为了缓解模型对标签数据的依赖,本发明在基于图神经网络的小样本学习方法的基础上结合自监督学习提出了一种自监督增强的小样本学习方法,本发明设计一个抠图位置预测的自监督学习任务,将每一个小样本分类任务的所有样本进行随机抠图,在提取样本特征后经过一个全连接层预测每一个样本被抠除图像块的位置。本发明将抠图位置预测任务和小样本分类任务联合训练以增强模型提取有效表征的能力,从而改善模型的分类结果,并通过在miniImageNet上的对比和消融实验证明了本发明的有效性。
Description
技术领域
本发明属于计算机视觉技术领域,具体涉及一种基于自监督增强的小样本图像分类方法及系统。
背景技术
近些年来,随着深度学习算法的不断发展,人工智能在多个领域得到了广泛的应用,例如医疗、交通、工业制造等。人工智能繁荣发展的背后是神经网络的不断加深以及所需数据量的不断增加,从而导致了数据采集与人工标注的成本不断上升。不仅如此,现有的深度神经网络模型往往泛化性很差,比如用大量的猫狗图片训练了一个良好的猫狗分类器,但是如果想将其用于鸟的识别就又需要大量的鸟的图片来训练。这时就希望模型可以减少对数据的依赖,像人类一样可以进行快速的学习,那么将会大大减少数据的人工标注成本,基于此小样本学习渐渐得到了许多研究者的关注。为了降低上述成本,学术界的许多研究者甚至是工业界将目光聚焦到小样本学习上,希望模型具有从少量数据中快速泛化的能力。
小样本学习的目的是让模型仅仅通过少量的样本就可以识别新的类,所以模型越是能够学习到通用的知识、越是能够有效地提取特征,其小样本学习能力往往会更强。现有的小样本模型大多设计为参数量较少,网络较浅的结构,这是因为越深的网络越难对新的类别进行快速泛化。然而在分类任务上,精度最高的模型都是在大规模数据集上训练的深层的网络,这是深度学习的弊端,要想使网络拥有出色的提取特征的能力,大规模的神经网络和大规模的标注数据集是缺一不可的。所以现有基于深度学习的小样本分类模型的一个瓶颈是怎样训练有限规模的网络提取尽可能富含语义的特征。这时候我们就想到了自监督学习,自监督学习不需要样本标签,其学习方式是从人为构建的辅助任务中进行训练,希望通过在辅助任务上的训练使得模型能够学到泛化性强的表示以便于用于后续的下游任务。从某种程度来说,小样本学习和自监督学习都是为了减少对标签的依赖,增强模型提取有效表征的能力,只不过前者的方法是在一定量的带标签数据上进行小样本学习任务,后者则彻底摆脱了标签从数据本身构造任务进行学习。本章的工作旨在进一步地增强模型的表征能力,所以本发明在小样本分类任务的基础上额外设计了一种自监督学习任务,用这个任务来促使模型学习到更有效、更鲁棒的表示。
发明内容
本发明所要解决的技术问题在于针对上述现有技术中的不足,提供一种基于自监督增强的小样本图像分类方法及系统,通过构建自监督学习任务来增强模型提取有效表征的能力以提高分类的准确率。
本发明采用以下技术方案:
一种基于自监督增强的小样本图像分类方法,包括以下步骤:
S1、采样小样本学习任务T;
S2、对步骤S1得到的小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务
S3、将步骤S1得到的小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;
S4、将步骤S3构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征
S5、构建边特征融合网络Ffus,对步骤S4得到的L层的边特征进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果得到小样本分类损失Lsu;
S6、将步骤S2得到的自监督小样本学习任务送入步骤S3中构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;
S7、将步骤S6得到的每一个样本的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;
S8、根据步骤S5得到的损失Lsu和步骤S7得到的自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
具体的,步骤S2中,自监督小样本学习任务为:
其中,表示经过随机抠图后的xi,/>表示xi被抠除图像块的位置,N为支持样本的类别数,K为支持样本每一类包含的样本,r为查询样本的个数。
具体的,步骤S3具体为:
S301、构建嵌入网络Femb,嵌入网络Femb的输入为每次从步骤S1中采样的B个小样本学习任务T,B为每批次的大小,输出为任务T中每个样本xi的特征图
S302、将步骤S301得到的特征图作为全连接图GT的初始结点特征然后进行边特征/>的初始化。
进一步的,步骤S302中,边特征的初始化具体为:
其中,yi和yj分别表示结点vi和vj的类别标签。
具体的,步骤S4具体为:
S401、输入步骤S3中得到的全连接图GT到图神经网络GNN进行更新,对于图神经网络GNN中的每一层,根据边特征更新结点特征;
S402、根据结点特征来新边特征,计算相邻结点特征和/>之间的关系矩阵Rij,然后将Rij送入一个边特征转换网络/>中再经过sigmoid操作来得到更新后的边特征/>
进一步的,步骤S402中,更新后的边特征具体为:
其中,表示网络中可学习的参数。
具体的,步骤S5中,小样本分类损失Lsu为:
其中,Lce表示交叉熵损失,xi表示查询集中的样本,yj为查询样本xi的标签,为查询结点vi的类别概率分布。
具体的,步骤S7中,自监督损失Lse为:
其中,表示/>被抠掉的位置标签,Lce表示交叉熵损失,/>为预测位置的概率分布。
具体的,步骤S8中,小样本分类损失Lsu和自监督损失Lse将小样本分类任务与抠图位置预测任务联合训练的总体损失L为:
L=αLsu+βLse
其中,α为0.8,β为0.2。
本发明的另一技术方案是,一种基于自监督增强的小样本图像分类系统,包括:
采样模块,采样小样本学习任务T;
抠图模块,对采样模块得到的小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务
嵌入模块,将采样模块得到的小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;
特征模块,将嵌入模块构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征
融合模块,构建边特征融合网络Ffus,将特征模块得到的L层的边特征进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果来得到一个小样本分类损失Lsu;
学习模块,将抠图模块得到的自监督小样本学习任务送入嵌入模块构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;
计算模块,将学习模块得到的每一个样本的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;
分类模块,根据融合模块得到的损失Lsu和计算模块得到的自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
与现有技术相比,本发明至少具有以下有益效果:
本发明一种基于自监督增强的小样本图像分类方法,设计了一种抠图位置预测的辅助任务,该任务在GNN基础上构建,首先将输入的支持和查询图像进行随机抠图,然后经过卷积模块提取特征后送入到一个全连接层中,全连接层的任务则是负责预测每个样本抠除的位置以产生一个自监督损失,将原始的支持和查询图像送入到GNN负责进行小样本分类任务产生一个监督分类损失,我们将GNN的原本的分类任务以及抠图位置预测任务一起联合训练,利用自监督辅助任务来增强模型提取有效特征的能力从而提升小样本分类任务的表现。
进一步的,步骤S2中,自监督小样本学习任务设计了一种自监督学习任务,随机抠除输入图像的一块,然后预测被抠除图像块的位置。通过加入自监督学习任务以提升模型提取有效特征的能力,进而提升模型的小样本学习能力。
进一步的,步骤S3中,构建嵌入网络Femb对图中的结点特征进行初始化并根据相邻结点类别的异同来初始化边特征,为后续利用图神经网络传播相邻结点信息以更新图表示做准备。
进一步的,嵌入网络Femb通过多个卷积层提取支持和查询样本的特征表示,并作为图中初始的结点特征,并将初始边特征构建为一个张量,表示相邻结点每一对应像素位置之间的相似程度而不是全局相似度,通过这种构建边特征的方式,使得后续结点特征每一像素位置独立聚合。
进一步的,步骤S4中,将构建的全连接图送入到图神经网络迭代进行结点特征更新及边特征更新,其中边特征的更新通过计算相似性矩阵R来收集相邻结点每一像素位置的语义相似性信息,然后通过边特征转换网络来得到每一层更新后的边特征/>
进一步的,由于矩阵R是一个较为庞大的矩阵,为了更有效地利用其中的语义相似性信息,将矩阵R送入到边特征转换网络来实现边特征的更新,通过这样的方式使得相邻结点每一像素位置单独计算相似度,并根据语义相似性信息突出语义相关区域,进而实现对查询样本更精准的分类。
进一步的,步骤S5中,使用分类任务中常见的交叉熵损失作为小样本分类损失Lsu,通过查询结点的类别概率分布以及每个查询样本的类别标签yi来训练模型对查询样本的类别进行有效预测。
进一步的,步骤S7中,使用分类任务中常见的交叉熵损失作为自监督损失Lse,通过让模型预测每个样本被抠掉的位置来促使模型更有效地提取表征。
进一步的,步骤S8中,我们将小样本分类任务和抠图位置预测任务联合训练,以增强模型提取有效特征的能力进而辅助小样本分类任务,对查询样本实现更精准的分类。
综上所述,本发明设计了一种辅助自监督任务来和小样本分类任务联合训练,这个自监督任务随机抠掉所有输入样本中的一块,同样让图神经网络预测查询结点的类别,并在特征提取网络后加入一个全连接层来预测每个样本被抠掉的位置,如果其可以正确预测每个样本被抠掉的位置则说明模型很好的学习到了全局的结构信息,从而说明模型可以更有效地提取特征。
下面通过附图和实施例,对本发明的技术方案做进一步的详细描述。
附图说明
图1为本发明的整体流程图;
图2为抠图位置预测任务示意图;
图3为边特征更新流程图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
在附图中示出了根据本发明公开实施例的各种结构示意图。这些图并非是按比例绘制的,其中为了清楚表达的目的,放大了某些细节,并且可能省略了某些细节。图中所示出的各种区域、层的形状及它们之间的相对大小、位置关系仅是示例性的,实际中可能由于制造公差或技术限制而有所偏差,并且本领域技术人员根据实际所需可以另外设计具有不同形状、大小、相对位置的区域/层。
本发明提供了一种基于自监督增强的小样本图像分类方法,在基于图神经网络的小样本学习方法基础上提出了一种基于自监督增强的小样本学习方法,基于抠图任务设计了抠图位置预测的自监督学习任务,即随机抠掉样本的一个图像块,然后预测被抠掉图像块的位置,根据预测位置的准确度来产生损失。本发明将辅助自监督任务和小样本分类任务联合训练以增强模型提取有效表征的能力,进而提升分类的精度。
请参阅图1,本发明一种基于自监督增强的小样本图像分类方法,包括以下步骤:
S1、从数据集中采样“N-way k-shot”小样本学习任务T=S∪Q,其中,表示带标签的支持集,xi表示样本,yi表示xi对应的类别标签,支持集共包含N个类,每类有K个样本,查询集Q则表示需要进行类别预测的无标签样本,若查询集中含有r个样本则/>
S2、对步骤S1得到的小样本学习任务T中的每个样本进行随机抠图,即随机抠除每个样本中的一个图像块,得到自监督小样本学习任务
请参阅图2,以3×3图像划分为9块为例,步骤S2具体为:
将输入大小为84×84的图像划分为49个12×12的图像块,则抠图位置预测的任务视为一个49类的分类任务,得到自监督小样本学习任务如下:
其中,表示经过随机抠图后的xi,/>表示xi被抠除图像块的位置。
S3、构建嵌入网络Femb,将小样本学习任务T中的所有样本送入到Femb中学习嵌入表示,得到每一个样本xi的特征图然后构建一个全连接图GT=(V,E),V={v1,...,vN*K+r}表示图中的结点,E={eij;vi,vj∈V}表示图中的边,且表示图中相邻结点vi和vj之间每一对应像素位置即vid∈R1*c与vjd∈R1*c之间的相似性;
S301、构建嵌入网络Femb,嵌入网络Femb包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层、第四卷积层以及输出层;其中嵌入网络Femb的输入为每次从步骤S1中采样的B个小样本学习任务T,B为每批次的大小,输出为任务T中每个样本xi的特征图
S302、将步骤S301得到的特征图作为图GT=(V,E)的初始结点特征然后按照下述方式进行边特征的初始化:
其中,yi和yj分别表示结点vi和vj的类别标签。
S4、构建图神经网络GNN,其共有L层,每一层包括结点特征更新和边特征更新两步;将步骤S3中得到的图GT输入到图神经网络GNN中得到每一层的边特征
S401、输入步骤S3中得到的图GT到GNN进行3层图更新(图2仅画出两层),对于SGNN中的每一层,根据边特征更新结点特征;
其中,||表示级联操作,和/>分别表示第l层的结点特征和边特征,/>表示第l层的结点特征转换网络,包括依次相连接的输入层、第一卷积层、第二卷积层以及输出层,表示网络中可学习的参数。
请参阅图3,步骤S402具体为:
S402、根据更新后的结点特征来更新边特征,首先计算相邻结点特征和/>之间的关系矩阵Rij;
其中,<,>表示向量内积。
请参阅图3,图中所有相邻结点之间的关系矩阵可以视为一个3维的矩阵,用R表示3维的关系矩阵。
然后将R送入一个边特征转换网络中再经过sigmoid操作来得到更新后的边特征/>
其中,边特征转换网络包括依次相连接的输入层、第一卷积层、第二卷积层以及输出层,且第一卷积层为组大小为16的分组卷积,/>表示网络中可学习的参数。
S5、构建边特征融合网络Ffus,将步骤S4中得到的L层的边特征进行级联,输入到Ffus中得到最终边表示/>用最终的边表示预测查询样本xi∈Q的类别,并根据预测结果来得到一个小样本分类损失Lsu;
边特征融合网络Ffus包括依次连接的输入层、卷积层以及输出层,输出层的输出为最终的边表示查询结点vi的类别概率分布计算如下:
其中,xi表示查询集中的样本,yj为支持样本xj的标签。
根据预测的类别分布及查询结点类别标签yi产生第一个小样本分类任务的监督损失Lsu:
其中,Lce表示交叉熵损失。
S6、将步骤S2得到的自监督小样本学习任务送入步骤S3中构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图/>
S7、构建抠图位置预测网络Fc,将步骤S6中得到的特征图输入抠图位置预测网络Fc中得到样本/>被抠掉图像块的位置,并根掘/>计算得到一个自监督损失Lse;
构建抠图位置预测网络Fc,抠图位置预测网络fc包括依次相连接的输入层、全连接层以及输出层;根据步骤S6中得到的样本特征预测每个样本被抠掉的位置:
其中,表示对/>被抠掉位置的预测,/>表示/>所有的样本,θc表示全连接层可学习的参数。
根据预测的分布产生一个自监督损失Lse:
其中,表示/>被抠掉的位置标签,Lce表示交叉熵损失。
S8、根据步骤S5得到的小样本分类损失Lsu和步骤S7得到的自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,以增强模型提取有效特征的能力,进而增益小样本分类任务。
总体的损失L为:
L=αLsu+βLse
其中,α为0.8,β为0.2。
本发明再一个实施例中,提供一种基于自监督增强的小样本图像分类系统,该系统能够用于实现上述基于自监督增强的小样本图像分类方法,具体的,该基于自监督增强的小样本图像分类系统包括采样模块、抠图模块、嵌入模块、特征模块、融合模块、学习模块、计算模块以及分类模块。
其中,采样模块,采样小样本学习任务T;
抠图模块,对采样模块得到的小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务
嵌入模块,将采样模块得到的小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;
特征模块,将嵌入模块构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征
融合模块,构建边特征融合网络Ffus,将特征模块得到的L层的边特征进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果来得到一个小样本分类损失Lsu;
学习模块,将抠图模块得到的自监督小样本学习任务送入嵌入模块构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;
计算模块,将学习模块得到的每一个样本的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;
分类模块,根据融合模块得到的损失Lsu和计算模块得到的自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
本发明再一个实施例中,提供了一种终端设备,该终端设备包括处理器以及存储器,所述存储器用于存储计算机程序,所述计算机程序包括程序指令,所述处理器用于执行所述计算机存储介质存储的程序指令。处理器可能是中央处理单元(Central ProcessingUnit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor、DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable GateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其是终端的计算核心以及控制核心,其适于实现一条或一条以上指令,具体适于加载并执行一条或一条以上指令从而实现相应方法流程或相应功能;本发明实施例所述的处理器可以用于基于自监督增强的小样本图像分类方法的操作,包括:
采样小样本学习任务T;对小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务将小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;将构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征/>构建边特征融合网络Ffus,对L层的边特征/>进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果得到小样本分类损失Lsu;将自监督小样本学习任务/>送入构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;将每一个样本/>的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;根据损失Lsu和自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
本发明再一个实施例中,本发明还提供了一种存储介质,具体为计算机可读存储介质(Memory),所述计算机可读存储介质是终端设备中的记忆设备,用于存放程序和数据。可以理解的是,此处的计算机可读存储介质既可以包括终端设备中的内置存储介质,当然也可以包括终端设备所支持的扩展存储介质。计算机可读存储介质提供存储空间,该存储空间存储了终端的操作系统。并且,在该存储空间中还存放了适于被处理器加载并执行的一条或一条以上的指令,这些指令可以是一个或一个以上的计算机程序(包括程序代码)。需要说明的是,此处的计算机可读存储介质可以是高速RAM存储器,也可以是非不稳定的存储器(non-volatile memory),例如至少一个磁盘存储器。
可由处理器加载并执行计算机可读存储介质中存放的一条或一条以上指令,以实现上述实施例中有关基于自监督增强的小样本图像分类方法的相应步骤;计算机可读存储介质中的一条或一条以上指令由处理器加载并执行如下步骤:
采样小样本学习任务T;对小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务将小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;将构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征/>构建边特征融合网络Ffus,对L层的边特征/>进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果得到小样本分类损失Lsu;将自监督小样本学习任务/>送入构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;将每一个样本/>的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;根据损失Lsu和自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。通常在此处附图中的描述和所示的本发明实施例的组件可以通过各种不同的配置来布置和设计。因此,以下对在附图中提供的本发明的实施例的详细描述并非旨在限制要求保护的本发明的范围,而是仅仅表示本发明的选定实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的效果可通过以下仿真结果进一步说明
1.仿真条件
本发明仿真的硬件条件为:智能感知与图像理解实验室图形工作站,打在一块显存为12G的GPU;本发明仿真所使用的数据集为miniImageNet数据集。数据集中所有的图片都是大小为84*84的3通道RGB图像,共包含了100类,每一类有大约600张图片。
本发明遵循了目前小样本学习方法的常用划分方式,将其中的64类用于训练,16类用于验证,20类用于测试。
2.仿真内容
利用miniImageNet数据集,在训练时,对于5way-1shot任务,我们将批大小设置为64,其支持集共有5个类别,每个类别有1个样本,并且每类有1个查询样本,所以一共10个样本来构建一个episode。对于5way-5shot任务,将批大小设置为20,其支持集同样有5个类别,但是每类有5个样本,每类同样有1个查询样本,所以一共30个样本来构建一个episode。
在验证阶段,随机从测试集中采样600个小样本分类任务,根据600个任务上的平均准确率来评价其性能。
表1本发明方法在miniImageNet数据集上的实验结果
3.仿真结果分析
从表1看出,本发明方法在miniImageNet上5way-1shot设置下的分类准确率达到了52.26%,在5way-5shot设置下达到了66.55%,较对比方法有了显著的提升。另外,在去掉自监督辅助任务的情况下,性能出现了下降,这证明了本发明提出的抠图位置预测任务的有效性。
综上所述,本发明一种基于自监督增强的小样本图像分类方法及系统,通过加入自监督任务增强了基于图神经网络的小样本学习模型的小样本学习能力,将抠图位置预测任务和小样本分类任务联合训练以增强模型提取有效表征的能力,从而改善模型的分类结果,并通过在miniImageNet上的对比和消融实验证明了本发明的有效性。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
以上内容仅为说明本发明的技术思想,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明权利要求书的保护范围之内。
Claims (9)
1.一种基于自监督增强的小样本图像分类方法,其特征在于,包括以下步骤:
S1、采样小样本学习任务T;
S2、对步骤S1得到的小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务自监督小样本学习任务/>为:
其中,表示经过随机抠图后的xi,/>表示xi被抠除图像块的位置,N为支持样本的类别数,K为支持样本每一类包含的样本,r为查询样本的个数;
S3、将步骤S1得到的小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;
S4、将步骤S3构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征
S5、构建边特征融合网络Ffus,对步骤S4得到的L层的边特征进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果得到小样本分类损失Lsu;
S6、将步骤S2得到的自监督小样本学习任务送入步骤S3中构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;
S7、将步骤S6得到的每一个样本的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;
S8、根据步骤S5得到的损失Lsu和步骤S7得到的自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
2.根据权利要求1所述的方法,其特征在于,步骤S3具体为:
S301、构建嵌入网络Femb,嵌入网络Femb的输入为每次从步骤S1中采样的B个小样本学习任务T,B为每批次的大小,输出为任务T中每个样本xi的特征图
S302、将步骤S301得到的特征图作为全连接图GT的初始结点特征/>然后进行边特征/>的初始化。
3.根据权利要求2所述的方法,其特征在于,步骤S302中,边特征的初始化具体为:
其中,yi和yj分别表示结点vi和vj的类别标签。
4.根据权利要求1所述的方法,其特征在于,步骤S4具体为:
S401、输入步骤S3中得到的全连接图GT到图神经网络GNN进行更新,对于图神经网络GNN中的每一层,根据边特征更新结点特征;
S402、根据结点特征来新边特征,计算相邻结点特征和/>之间的关系矩阵Rij,然后将Rij送入一个边特征转换网络/>中再经过sigmoid操作来得到更新后的边特征/>
5.根据权利要求4所述的方法,其特征在于,步骤S402中,更新后的边特征具体为:
其中,表示网络中可学习的参数。
6.根据权利要求1所述的方法,其特征在于,步骤S5中,小样本分类损失Lsu为:
其中,Lce表示交叉熵损失,xi表示查询集中的样本,yj为查询样本xi的标签,为查询结点vi的类别概率分布。
7.根据权利要求1所述的方法,其特征在于,步骤S7中,自监督损失Lse为:
其中,表示/>被抠掉的位置标签,Lce表示交叉熵损失,/>为预测位置的概率分布。
8.根据权利要求1所述的方法,其特征在于,步骤S8中,小样本分类损失Lsu和自监督损失Lse将小样本分类任务与抠图位置预测任务联合训练的总体损失L为:
L=αLsu+βLse
其中,α为0.8,β为0.2。
9.一种基于自监督增强的小样本图像分类系统,其特征在于,包括:
采样模块,采样小样本学习任务T;
抠图模块,对采样模块得到的小样本学习任务T中的每个样本进行随机抠图,得到自监督小样本学习任务自监督小样本学习任务/>为:
其中,表示经过随机抠图后的xi,/>表示xi被抠除图像块的位置,N为支持样本的类别数,K为支持样本每一类包含的样本,r为查询样本的个数;
嵌入模块,将采样模块得到的小样本学习任务T中的所有样本送入构建的嵌入网络Femb中,得到每一个样本xi的特征图;然后构建全连接图GT;
特征模块,将嵌入模块构建的全连接图GT输入构建的图神经网络GNN中,得到每一层的边特征
融合模块,构建边特征融合网络Ffus,将特征模块得到的L层的边特征进行级联,输入边特征融合网络Ffus中得到最终边/>用最终边表示预测查询样本xi∈Q的类别,Q表示查询集,并根据预测结果来得到一个小样本分类损失Lsu;
学习模块,将抠图模块得到的自监督小样本学习任务送入嵌入模块构建的嵌入网络Femb中学习嵌入表示,得到每一个样本/>的特征图;
计算模块,将学习模块得到的每一个样本的特征图输入构建的抠图位置预测网络Fc中,得到样本/>被抠掉图像块的位置,并根据/>计算得到一个自监督损失Lse;
分类模块,根据融合模块得到的损失Lsu和计算模块得到的自监督损失Lse联合训练模型进行小样本分类以及抠图位置预测任务,实现增益小样本分类。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110657337.XA CN113378937B (zh) | 2021-06-11 | 2021-06-11 | 一种基于自监督增强的小样本图像分类方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110657337.XA CN113378937B (zh) | 2021-06-11 | 2021-06-11 | 一种基于自监督增强的小样本图像分类方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113378937A CN113378937A (zh) | 2021-09-10 |
CN113378937B true CN113378937B (zh) | 2023-08-11 |
Family
ID=77574436
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110657337.XA Active CN113378937B (zh) | 2021-06-11 | 2021-06-11 | 一种基于自监督增强的小样本图像分类方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113378937B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113989556B (zh) * | 2021-10-27 | 2024-04-09 | 南京大学 | 一种小样本医学影像分类方法和系统 |
CN114863234A (zh) * | 2022-04-29 | 2022-08-05 | 华侨大学 | 一种基于拓扑结构保持的图表示学习方法及系统 |
CN115100390B (zh) * | 2022-08-24 | 2022-11-18 | 华东交通大学 | 一种联合对比学习与自监督区域定位的图像情感预测方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019136946A1 (zh) * | 2018-01-15 | 2019-07-18 | 中山大学 | 基于深度学习的弱监督显著性物体检测的方法及系统 |
CN112348792A (zh) * | 2020-11-04 | 2021-02-09 | 广东工业大学 | 一种基于小样本学习和自监督学习的x光胸片图像分类方法 |
CN112766378A (zh) * | 2021-01-19 | 2021-05-07 | 北京工商大学 | 一种专注细粒度识别的跨域小样本图像分类模型方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11308353B2 (en) * | 2019-10-23 | 2022-04-19 | Adobe Inc. | Classifying digital images in few-shot tasks based on neural networks trained using manifold mixup regularization and self-supervision |
-
2021
- 2021-06-11 CN CN202110657337.XA patent/CN113378937B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019136946A1 (zh) * | 2018-01-15 | 2019-07-18 | 中山大学 | 基于深度学习的弱监督显著性物体检测的方法及系统 |
CN112348792A (zh) * | 2020-11-04 | 2021-02-09 | 广东工业大学 | 一种基于小样本学习和自监督学习的x光胸片图像分类方法 |
CN112766378A (zh) * | 2021-01-19 | 2021-05-07 | 北京工商大学 | 一种专注细粒度识别的跨域小样本图像分类模型方法 |
Non-Patent Citations (1)
Title |
---|
多级注意力特征网络的小样本学习;汪荣贵;韩梦雅;杨娟;薛丽霞;胡敏;;电子与信息学报(第03期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN113378937A (zh) | 2021-09-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113378937B (zh) | 一种基于自监督增强的小样本图像分类方法及系统 | |
CN109584337B (zh) | 一种基于条件胶囊生成对抗网络的图像生成方法 | |
CN110910391B (zh) | 一种双模块神经网络结构视频对象分割方法 | |
CN112686850B (zh) | 基于空间位置和原型网络的ct图像的少样本分割方法和系统 | |
CN113378938B (zh) | 一种基于边Transformer图神经网络的小样本图像分类方法及系统 | |
CN107679501B (zh) | 一种基于标签自提纯的深度学习方法 | |
CN115359074B (zh) | 基于超体素聚类及原型优化的图像分割、训练方法及装置 | |
CN113420827A (zh) | 语义分割网络训练和图像语义分割方法、装置及设备 | |
CN115797632B (zh) | 一种基于多任务学习的图像分割方法 | |
CN111782804B (zh) | 基于TextCNN同分布文本数据选择方法、系统及存储介质 | |
CN113673482B (zh) | 基于动态标签分配的细胞抗核抗体荧光识别方法及系统 | |
CN111241933A (zh) | 一种基于通用对抗扰动的养猪场目标识别方法 | |
CN105701225A (zh) | 一种基于统一关联超图规约的跨媒体检索方法 | |
CN116091946A (zh) | 一种基于YOLOv5的无人机航拍图像目标检测方法 | |
Dhawan et al. | Deep Learning Based Sugarcane Downy Mildew Disease Detection Using CNN-LSTM Ensemble Model for Severity Level Classification | |
CN114333062A (zh) | 基于异构双网络和特征一致性的行人重识别模型训练方法 | |
CN113223037A (zh) | 一种面向大规模数据的无监督语义分割方法及系统 | |
CN113378934B (zh) | 一种基于语义感知图神经网络的小样本图像分类方法及系统 | |
Das et al. | Object Detection on Scene Images: A Novel Approach | |
Okawa et al. | Detection of abnormal fish by image recognition using fine-tuning | |
CN110851633B (zh) | 一种实现同时定位和哈希的细粒度图像检索方法 | |
CN116343104B (zh) | 视觉特征与向量语义空间耦合的地图场景识别方法及系统 | |
CN116597419B (zh) | 一种基于参数化互近邻的车辆限高场景识别方法 | |
김종태 | SF Network | |
Ramanathan et al. | QUICKSAL: A small and sparse visual saliency model for efficient inference in resource constrained hardware |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |