CN113378938A - 一种基于边Transformer图神经网络的小样本图像分类方法及系统 - Google Patents

一种基于边Transformer图神经网络的小样本图像分类方法及系统 Download PDF

Info

Publication number
CN113378938A
CN113378938A CN202110657352.4A CN202110657352A CN113378938A CN 113378938 A CN113378938 A CN 113378938A CN 202110657352 A CN202110657352 A CN 202110657352A CN 113378938 A CN113378938 A CN 113378938A
Authority
CN
China
Prior art keywords
edge
layer
sample
neural network
graph
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110657352.4A
Other languages
English (en)
Other versions
CN113378938B (zh
Inventor
刘芳
张瀚
马文萍
李玲玲
李鹏芳
杨苗苗
刘洋
刘旭
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Xidian University
Original Assignee
Xidian University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Xidian University filed Critical Xidian University
Priority to CN202110657352.4A priority Critical patent/CN113378938B/zh
Publication of CN113378938A publication Critical patent/CN113378938A/zh
Application granted granted Critical
Publication of CN113378938B publication Critical patent/CN113378938B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于边Transformer图神经网络的小样本图像分类方法及系统,引入transformer模型来对图中的边特征进行更新,通过将结点之间的差值特征图拆分成特征块序列输入到transformer模块得到更新后的边特征,使得每一个像素位置会被分配不同的注意力权重以突出关键区域。本发明的思想是利用transformer中的自注意力机制来自动聚焦到用于衡量结点间相似性的关键区域,从而达到抑制背景信息并突出关键区域的目的。本发明在miniImageNet数据集上进行的对比实验证明了本发明可以提高小样本图像分类的精度。

Description

一种基于边Transformer图神经网络的小样本图像分类方法 及系统
技术领域
本发明属于计算机视觉技术领域,具体涉及一种基于边Transformer图神经网络的小样本图像分类方法及系统。
背景技术
近年来,得益于计算机算力的飞速提升,深度学习成为了人工智能领域研究的热门。然而深度学习繁荣发展的背后是大规模人工标注的数据集的支撑,并且越加复杂的网络就越需要更加庞大的数据集来训练。但是在一些特殊的领域数据是非常匮乏的,如医学中罕见病例的判别,其现有的有限医学图像是远远不够用于训练一个良好的深度模型的。这时就希望模型可以减少对数据的依赖,像人类一样可以进行快速的学习,那么将会大大减少数据的人工标注成本,基于此小样本学习渐渐得到了许多研究者的关注。小样本学习顾名思义就是在带标注数据不充足的情况下进行的学习任务,一个优秀的小样本学习模型通过一定量任务的训练后,不需要进行额外的训练就可以泛化到新的任务上。现有的小样本学习方法大致可以分为基于度量学习、基于元学习、基于数据增强以及基于图神经网络四种。
现有的基于图神经网络的小样本学习模型基于全局相似度来进行结点特征的聚合,这种方式会聚合许多背景信息进而引起语义上的歧义,为了解决这一问题,受transformer模型的启发,本发明提出一种利用transformer感知关键区域的小样本学习方法,抛弃了CNN转而使用transformer编码器来对GNN中的边特征进行更新。动机来自于transformer模型结构中含有的自注意力层,使得模型天生具有感知关键区域的能力,所以本章的方法利用其中的自注意力机制来自动学习为不同的像素位置分配不同的注意力,进而在结点特征聚合的时候给予关键区域更多的关注。
发明内容
本发明所要解决的技术问题在于针对上述现有技术中的不足,提供一种基于边Transformer图神经网络的小样本图像分类方法及系统,通过利用transformer模型结构中含有的自注意力层,使得模型天生具有感知关键区域的能力,进而增强模型在小样本图像分类任务上的表现。
本发明采用以下技术方案:
一种基于边Transformer图神经网络的小样本图像分类方法,包括以下步骤:
S1、采样小样本学习任务T;
S2、将步骤S1得到的小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi
S3、构建一个全连接图GT,将步骤S2中每一个样本的特征图fi作为初始的结点特征
Figure BDA0003113663250000021
并根据查询样本的标签初始化边特征
Figure BDA0003113663250000022
S4、将步骤S3构建的全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure BDA0003113663250000023
S5、对步骤S4得到的L层的边特征
Figure BDA0003113663250000024
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure BDA0003113663250000025
并根据最终边
Figure BDA0003113663250000026
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure BDA0003113663250000027
S6、根据步骤S5得到的查询结点的类别概率分布
Figure BDA0003113663250000028
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,测试时利用类别概率分布
Figure BDA0003113663250000031
对查询样本进行类别预测以实现小样本分类。
具体的,步骤S2中,嵌入网络Femb的输入为每次从步骤S1中采样的B个小样本学习任务T,B为每批次的大小,输出为任务T中每个样本xi的特征图
Figure BDA0003113663250000032
进一步的,嵌入网络Femb包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层、第四卷积层以及输出层。
具体的,步骤S3中,全连接图GT=(V,E),V={v1,…,vN**K+r}表示图中的结点,E={eij;vi,vj∈V}表示图中的边,且
Figure BDA0003113663250000033
表示图中相邻结点vi和vj之间每一对应像素位置即vid∈R1*与vjd∈R1**之间的相似性;将步骤S2得到的特征图
Figure BDA0003113663250000034
作为图GT=(V,E)的初始结点特征
Figure BDA0003113663250000035
然后进行边特征初始化。
进一步的,边特征初始化具体为:
Figure BDA0003113663250000036
其中,yi和yj分别表示结点vi和vj的类别标签。
具体的,步骤S4具体为:
S401、输入步骤S3中得到的图GT到边transformer图神经网络SGNN进行3层图更新,对于边transformer图神经网络SGNN中的每一层,根据边特征更新结点特征;
S402、将每一层更新后的结点特征输入到边Transformer模块更新边特征,通过边Transformer模块计算相邻结点特征间的差值特征图
Figure BDA0003113663250000037
然后将差值特征图
Figure BDA0003113663250000041
进行拆分,得到w*h个差值特征块组成的序列pl,引入可学习的位置编码
Figure BDA0003113663250000042
然后将序列pl和位置编码
Figure BDA0003113663250000043
进行级联得到序列
Figure BDA0003113663250000044
最后将序列
Figure BDA0003113663250000045
送入到边Transformer模块得到更新后的边特征
Figure BDA0003113663250000046
进一步的,边Transformer模块包含两个子层,每一个子层后接一个LN层进行标准化,第一个子层为自注意力层,第二个子层为前馈网络层;
第一个子层中,首先得到序列
Figure BDA0003113663250000047
的每个位置的查询向量
Figure BDA0003113663250000048
键向量
Figure BDA0003113663250000049
以及值向量
Figure BDA00031136632500000410
然后计算自注意力
Figure BDA00031136632500000411
将输入序列
Figure BDA00031136632500000412
和自注意力
Figure BDA00031136632500000413
做一个残差连接
Figure BDA00031136632500000414
第二个子层中,输出的边特征维度
Figure BDA00031136632500000415
为:
Figure BDA00031136632500000416
其中,MLP为一个包含两个全连接层的多层感知机。
具体的,步骤S5中,查询结点vi的类别概率分布
Figure BDA00031136632500000417
计算如下:
Figure BDA00031136632500000418
其中,xi表示查询集中的样本,xj表示支持集中的样本,yj为支持样本xj的标签,
Figure BDA00031136632500000421
为最终边。
具体的,步骤S6中,小样本分类损失Lfl
Figure BDA00031136632500000419
其中,Lce表示交叉熵损失,
Figure BDA00031136632500000420
为查询结点vi的类别概率分布,yj为支持样本xj的标签。
本发明的另一技术方案是,一种基于边Transformer图神经网络的小样本图像分类系统,包括:
采样模块,采样小样本学习任务T;
特征模块,将采样模块得到的小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi
全连接模块,构建一个全连接图GT,将特征模块中每一个样本的特征图fi作为初始的结点特征
Figure BDA0003113663250000051
并根据查询样本的标签初始化边特征
Figure BDA0003113663250000052
神经网络模块,将全连接模块构建的全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure BDA0003113663250000053
融合模块,对神经网络模块得到的L层的边特征
Figure BDA0003113663250000054
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure BDA0003113663250000055
并根据最终边
Figure BDA0003113663250000056
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure BDA0003113663250000057
分类模块,根据融合模块得到的查询结点的类别概率分布
Figure BDA0003113663250000058
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,在测试时利用类别概率分布
Figure BDA0003113663250000059
对查询样本进行类别预测以实现小样本分类。
与现有技术相比,本发明至少具有以下有益效果:
本发明一种基于边Transformer图神经网络的小样本图像分类方法,构建边transformer图神经网络ETGNN,为了感知关键区域以减少背景引起的歧义,基于transformer具有纵观全局来赋予各个区域不同注意力的能力,提出了基于边transformer图神经网络的小样本学习方法,抑制背景对分类结果的影响,提高模型的小样本学习能力。
进一步的,嵌入网络Femb是一个由卷积模块构成的浅层网络,包括依次连接的输入层、第一卷积层、第二卷积层、第三卷积层、第四卷积层以及输出层,浅层的网络有利于模型在面对新的小样本学习任务时快速泛化。
进一步的,嵌入网络Femb通过多个卷积层提取支持和查询样本的特征表示,并作为图中初始的结点特征,并将初始边特征
Figure BDA0003113663250000061
构建为一个张量,表示相邻结点每一对应像素位置之间的相似程度而不是全局相似度,通过这种构建边特征的方式,使得后续结点特征每一像素位置独立聚合。
进一步的,以从嵌入网络Femb中提取的特征作为图中的初始结点特征,并根据相邻结点类别的异同来初始化边特征,为后续利用图神经网络传播相邻结点信息以更新图表示做准备。
进一步的,初始边特征
Figure BDA0003113663250000062
构建为一个张量,表示相邻结点每一对应像素位置之间的相似程度而不是全局相似度,通过这种构建边特征的方式,使得后续结点特征每一像素位置独立聚合。
进一步的,将构建的全连接图送入到边transformer图神经网络迭代进行结点特征更新及边特征更新,其中边特征的更新将结点之间的差值特征图拆分成序列并加入位置编码后作为transformer编码器的输入来更新边的表示,利用transformer中的自注意力机制来帮助模型聚焦到用于衡量相似性的关键区域。
进一步的,transformer编码器由自注意力层和前馈网络层构成,自注意力层负责计算输入序列的自注意力,前馈网络层将自注意力映射为更新后的边特征。这种更新边特征的方式利用自注意力机制来自动学习为不同的像素位置分配不同的注意力,进而在结点特征聚合的时候给予关键区域更多的关注。
进一步的,根据最终的边表示
Figure BDA0003113663250000063
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure BDA0003113663250000064
此概率分布在训练时用于计算损失,在测试时则用于对查询样本进行类别预测。
进一步的,使用分类任务中常见的交叉熵损失作为小样本分类损失Lsu,通过查询结点的类别概率分布
Figure BDA0003113663250000071
以及每个查询样本的类别标签yi来训练模型对查询样本的类别进行有效预测。
综上所述,本发明引入边transformer模块对图中的边特征进行更新,通过将结点之间的差值特征图拆分成特征块序列输入到transformer模块得到更新后的边特征,利用边transformer模块中的自注意力机制来自动聚焦到用于衡量结点间相似性的关键区域,从而达到抑制背景信息并突出关键区域的目的。
下面通过附图和实施例,对本发明的技术方案做进一步的详细描述。
附图说明
图1为本发明的实现流程图;
图2为边transformer模块(ETM)图;
图3为transformer编码器结构细节图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
在附图中示出了根据本发明公开实施例的各种结构示意图。这些图并非是按比例绘制的,其中为了清楚表达的目的,放大了某些细节,并且可能省略了某些细节。图中所示出的各种区域、层的形状及它们之间的相对大小、位置关系仅是示例性的,实际中可能由于制造公差或技术限制而有所偏差,并且本领域技术人员根据实际所需可以另外设计具有不同形状、大小、相对位置的区域/层。
本发明提供了一种基于边Transformer图神经网络的小样本图像分类方法,能够减少背景引入的歧义,突出结点中关键区域之间的关系。通过将结点特征间的差值特征图拆分成序列送入到一个transformer编码器来得到更新后的边特征,利用transformer中的自注意力机制来帮助模型感知到需要关注的关键区域从而达到抑制背景的目的,得益于transformer中的位置编码,其输出的边还综合考虑了空间位置信息,这是目前其他基于的小样本学习方法中所没有的。
请参阅图1,本发明一种基于边Transformer图神经网络的小样本图像分类方法,包括以下步骤:
S1、从数据集中采样“N-way k-shot”小样本学习任务T=S∪Q,其中,
Figure BDA0003113663250000081
表示带标签的支持集,xi表示样本,yi表示xi对应的类别标签,支持集共包含N个类,每类有K个样本,查询集Q表示需要进行类别预测的无标签样本,若查询集中含有r个样本则
Figure BDA0003113663250000091
S2、构建嵌入网络Femb,将T中的所有样本送入到Femb中学习嵌入表示,得到每一个样本xi的特征图
Figure BDA0003113663250000092
构建嵌入网络Femb,嵌入网络Femb包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层、第四卷积层以及输出层;其中嵌入网络Femb的输入为每次从步骤S1中采样的B个小样本学习任务T,B为每批次的大小,输出为任务T中每个样本xi的特征图
Figure BDA0003113663250000093
S3、构建一个全连接图GT=(V,E),V={v1,…,vN*K+r}表示图中的结点,E={eij;vi,vj∈V}表示图中的边,且
Figure BDA0003113663250000094
表示图中相邻结点vi和vj之间每一对应像素位置即vid∈R1*c与vjd∈R1*c之间的相似性;
将步骤S2得到的特征图
Figure BDA0003113663250000095
作为图GT=(V,E)的初始结点特征
Figure BDA0003113663250000096
然后按照下述方式进行边特征的初始化:
Figure BDA0003113663250000097
其中,yi和yj分别表示结点vi和vj的类别标签。
S4、构建边transformer图神经网络ETGNN,边transformer图神经网络ETGNN共有L层,每一层包括结点特征更新和边特征更新两步,边特征的更新是由边transformer模块ETM实现;将步骤S3中得到的图GT输入到边transformer图神经网络ETGNN中得到每一层的边特征
Figure BDA0003113663250000098
请参阅图1和图2,步骤S4具体为:
S401、输入步骤S3中得到的全连接图GT到边transformer图神经网络SGNN进行3层图更新(图1仅画出两层),对于边transformer图神经网络SGNN中的每一层,根据边特征更新结点特征;
Figure BDA0003113663250000101
其中,||表示级联操作,
Figure BDA0003113663250000102
Figure BDA0003113663250000103
分别表示第l层的结点特征和边特征,
Figure BDA0003113663250000104
表示第l层的结点特征转换网络,包括依次相连接的输入层、第一卷积层、第二卷积层以及输出层,
Figure BDA0003113663250000105
表示网络中可学习的参数。
S402、将每一层更新后的结点特征输入到边Transformer模块ETM来更新边特征,ETM首先计算相邻结点特征间的差值特征图,以第l层的两个相邻结点的特征
Figure BDA0003113663250000106
Figure BDA0003113663250000107
为例,差值特征图
Figure BDA0003113663250000108
计算为:
Figure BDA0003113663250000109
然后将
Figure BDA00031136632500001010
进行拆分,得到w*h个差值特征块组成的序列
Figure BDA00031136632500001011
并引入了可学习的位置编码
Figure BDA00031136632500001012
然后将序列pl和位置编码
Figure BDA00031136632500001013
进行级联得到序列
Figure BDA00031136632500001014
Figure BDA00031136632500001015
最后将
Figure BDA00031136632500001016
送入到一个边Transformer模块得到更新后的边特征
Figure BDA00031136632500001017
请参阅图3,边Transformer模块具体为:
边Transformer模块包含两个子层,每一个子层后会接一个LN层进行标准化,其中第一个子层为自注意力层,首先要得到序列
Figure BDA00031136632500001018
的每个位置的查询向量
Figure BDA00031136632500001019
键向量
Figure BDA00031136632500001020
以及值向量
Figure BDA00031136632500001021
Figure BDA00031136632500001022
其中,
Figure BDA00031136632500001023
为查询矩阵,
Figure BDA00031136632500001024
为键矩阵,
Figure BDA00031136632500001025
为值矩阵,且Wqkγ∈R(c+t)×3r是需要学习的参数。然后计算自注意力:
Figure BDA00031136632500001026
其中,
Figure BDA0003113663250000111
Wout∈Rr×(c+t)是需要学习的参数,在得到
Figure BDA0003113663250000112
之后,我们将输入序列
Figure BDA0003113663250000113
和其做一个残差连接:
Figure BDA0003113663250000114
第二个子层为前馈网络层,负责将自注意力层的输出映射到边Transformer模块最终的输出即更新后的边特征,这一过程表述为:
Figure BDA0003113663250000115
其中,输出的边特征维度为
Figure BDA0003113663250000116
S5、构建边特征融合网络Ffus,将步骤S4中得到的L层的边特征
Figure BDA0003113663250000117
进行级联,输入到Ffus中得到最终边表示
Figure BDA0003113663250000118
边特征融合网络Ffus包括依次连接的输入层、卷积层以及输出层,输出层的输出为最终的边表示
Figure BDA0003113663250000119
则查询结点vi的类别概率分布计算如下:
Figure BDA00031136632500001110
其中,xi表示查询集中的样本,yj为支持样本xj的标签。
S6、根据S5中得到的查询结点的类别概率分布
Figure BDA00031136632500001111
及查询结点类别标签yi来计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,测试时利用类别概率分布
Figure BDA00031136632500001112
对查询样本进行类别预测以实现小样本分类。
小样本分类损失Lfl
Figure BDA00031136632500001113
其中,Lce表示交叉熵损失。
本发明再一个实施例中,提供一种基于边Transformer图神经网络的小样本图像分类系统,该系统能够用于实现上述基于边Transformer图神经网络的小样本图像分类方法,具体的,该基于边Transformer图神经网络的小样本图像分类系统包括采样模块、特征模块、全连接模块、神经网络模块、融合模块以及分类模块。
其中,采样模块,采样小样本学习任务T;
特征模块,将采样模块得到的小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi
全连接模块,构建一个全连接图GT,将特征模块中每一个样本的特征图fi作为初始的结点特征
Figure BDA0003113663250000121
并根据查询样本的标签初始化边特征
Figure BDA0003113663250000122
神经网络模块,将全连接模块构建的全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure BDA0003113663250000123
融合模块,对神经网络模块得到的L层的边特征
Figure BDA0003113663250000124
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure BDA0003113663250000125
并根据最终边
Figure BDA0003113663250000126
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure BDA0003113663250000127
分类模块,根据融合模块得到的查询结点的类别概率分布
Figure BDA0003113663250000128
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,测试时利用类别概率分布
Figure BDA0003113663250000129
对查询样本进行类别预测以实现小样本分类。
本发明再一个实施例中,提供了一种终端设备,该终端设备包括处理器以及存储器,所述存储器用于存储计算机程序,所述计算机程序包括程序指令,所述处理器用于执行所述计算机存储介质存储的程序指令。处理器可能是中央处理单元(Central ProcessingUnit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor、DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable GateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其是终端的计算核心以及控制核心,其适于实现一条或一条以上指令,具体适于加载并执行一条或一条以上指令从而实现相应方法流程或相应功能;本发明实施例所述的处理器可以用于基于边Transformer图神经网络的小样本图像分类方法的操作,包括:
采样小样本学习任务T;将小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi;构建一个全连接图GT,将每一个样本的特征图fi作为初始的结点特征
Figure BDA0003113663250000131
并根据查询样本的标签初始化边特征
Figure BDA0003113663250000132
将全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure BDA0003113663250000133
对L层的边特征
Figure BDA0003113663250000134
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure BDA0003113663250000135
并根据最终边
Figure BDA0003113663250000136
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure BDA0003113663250000137
根据查询结点的类别概率分布
Figure BDA0003113663250000138
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,测试时利用类别概率分布
Figure BDA0003113663250000139
对查询样本进行类别预测以实现小样本分类。
本发明再一个实施例中,本发明还提供了一种存储介质,具体为计算机可读存储介质(Memory),所述计算机可读存储介质是终端设备中的记忆设备,用于存放程序和数据。可以理解的是,此处的计算机可读存储介质既可以包括终端设备中的内置存储介质,当然也可以包括终端设备所支持的扩展存储介质。计算机可读存储介质提供存储空间,该存储空间存储了终端的操作系统。并且,在该存储空间中还存放了适于被处理器加载并执行的一条或一条以上的指令,这些指令可以是一个或一个以上的计算机程序(包括程序代码)。需要说明的是,此处的计算机可读存储介质可以是高速RAM存储器,也可以是非不稳定的存储器(non-volatile memory),例如至少一个磁盘存储器。
可由处理器加载并执行计算机可读存储介质中存放的一条或一条以上指令,以实现上述实施例中有关基于边Transformer图神经网络的小样本图像分类方法的相应步骤;计算机可读存储介质中的一条或一条以上指令由处理器加载并执行如下步骤:
采样小样本学习任务T;将小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi;构建一个全连接图GT,将每一个样本的特征图fi作为初始的结点特征
Figure BDA0003113663250000141
并根据查询样本的标签初始化边特征
Figure BDA0003113663250000142
将全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure BDA0003113663250000143
对L层的边特征
Figure BDA0003113663250000144
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure BDA0003113663250000145
并根据最终边
Figure BDA0003113663250000146
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure BDA0003113663250000147
根据查询结点的类别概率分布
Figure BDA0003113663250000148
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,测试时利用类别概率分布
Figure BDA0003113663250000149
对查询样本进行类别预测以实现小样本分类。
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。通常在此处附图中的描述和所示的本发明实施例的组件可以通过各种不同的配置来布置和设计。因此,以下对在附图中提供的本发明的实施例的详细描述并非旨在限制要求保护的本发明的范围,而是仅仅表示本发明的选定实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的效果可通过以下仿真结果进一步说明
1.仿真条件
本发明仿真的硬件条件为:智能感知与图像理解实验室图形工作站,打在一块显存为12G的GPU;本发明仿真所使用的数据集为miniImageNet数据集。数据集中所有的图片都是大小为84*84的3通道RGB图像,共包含了100类,每一类有大约600张图片。
本发明遵循了目前小样本学习方法的常用划分方式,将其中的64类用于训练,16类用于验证,20类用于测试。
2.仿真内容
利用miniImageNet数据集,在训练时,对于5way-1shot任务,我们将批大小设置为64,其支持集共有5个类别,每个类别有1个样本,并且每类有1个查询样本,所以一共10个样本来构建一个episode。对于5way-5shot任务,将批大小设置为20,其支持集同样有5个类别,但是每类有5个样本,每类同样有1个查询样本,所以一共30个样本来构建一个episode。
在验证阶段,随机从测试集中采样600个小样本分类任务,根据600个任务上的平均准确率来评价其性能。表1给出了本发明方法和其他一些小样本学习方法的对比实验结果。
表1本发明方法在miniImageNet数据集上的对比实验结果
模型名称 5way-1shot 5way-5shot
MN 46.60% 55.30%
PN 46.14% 65.77%
RN 50.44% 65.32%
GNN 50.33% 66.41%
本发明方法 51.75% 66.47%
3.仿真结果分析
从表1可以看出,本发明方法在miniImageNet上5way-1shot设置下的分类准确率达到了51.75%,在5way-5shot设置下达到了66.47%,较对比方法有了显著的提升。
综上所述,本发明一种基于边Transformer图神经网络的小样本图像分类方法及系统,将图中相邻结点特征的差值特征图分成序列,然后在加入位置编码后作为原始的输入序列送入到transformer编码器中学习得到更新后的边特征,这样得到的边不仅考虑了特征图的空间位置关系,并且得益于transformer模块中的自注意力机制,每一个像素位置会被分配不同的注意力权重以突出关键区域,从而提升模型的性能。本发明在miniImageNet数据集上的对比实验证明了ETGNN的可以提高小样本图像分类的精度。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
以上内容仅为说明本发明的技术思想,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明权利要求书的保护范围之内。

Claims (10)

1.一种基于边Transformer图神经网络的小样本图像分类方法,其特征在于,包括以下步骤:
S1、采样小样本学习任务T;
S2、将步骤S1得到的小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi
S3、构建一个全连接图GT,将步骤S2中每一个样本的特征图fi作为初始的结点特征
Figure FDA0003113663240000011
并根据查询样本的标签初始化边特征
Figure FDA0003113663240000012
S4、将步骤S3构建的全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure FDA0003113663240000013
S5、对步骤S4得到的L层的边特征
Figure FDA0003113663240000014
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure FDA0003113663240000015
并根据最终边
Figure FDA0003113663240000016
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure FDA0003113663240000017
S6、根据步骤S5得到的查询结点的类别概率分布
Figure FDA0003113663240000018
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,测试时利用类别概率分布
Figure FDA0003113663240000019
对查询样本进行类别预测以实现小样本分类。
2.根据权利要求1所述的方法,其特征在于,步骤S2中,嵌入网络Femb的输入为每次从步骤S1中采样的B个小样本学习任务T,B为每批次的大小,输出为任务T中每个样本xi的特征图
Figure FDA00031136632400000110
3.根据权利要求2所述的方法,其特征在于,嵌入网络Femb包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层、第四卷积层以及输出层。
4.根据权利要求1所述的方法,其特征在于,步骤S3中,全连接图GT=(V,E),V={v1,...,vN*K+r}表示图中的结点,E={eij;vi,vj∈V}表示图中的边,且
Figure FDA0003113663240000021
表示图中相邻结点vi和vj之间每一对应像素位置即vid∈R1*c与vjd∈R1*c之间的相似性;将步骤S2得到的特征图
Figure FDA0003113663240000022
作为图GT=(V,E)的初始结点特征
Figure FDA0003113663240000023
然后进行边特征初始化。
5.根据权利要求4所述的方法,其特征在于,边特征初始化具体为:
Figure FDA0003113663240000024
其中,yi和yj分别表示结点vi和vj的类别标签。
6.根据权利要求1所述的方法,其特征在于,步骤S4具体为:
S401、输入步骤S3中得到的图GT到边transformer图神经网络SGNN进行3层图更新,对于边transformer图神经网络SGNN中的每一层,根据边特征更新结点特征;
S402、将每一层更新后的结点特征输入到边Transformer模块更新边特征,通过边Transformer模块计算相邻结点特征间的差值特征图
Figure FDA0003113663240000025
然后将差值特征图
Figure FDA0003113663240000026
进行拆分,得到w*h个差值特征块组成的序列pl,引入可学习的位置编码
Figure FDA0003113663240000027
然后将序列pl和位置编码
Figure FDA0003113663240000028
进行级联得到序列
Figure FDA0003113663240000029
最后将序列
Figure FDA00031136632400000210
送入到边Transformer模块得到更新后的边特征
Figure FDA00031136632400000211
7.根据权利要求6所述的方法,其特征在于,边Transformer模块包含两个子层,每一个子层后接一个LN层进行标准化,第一个子层为自注意力层,第二个子层为前馈网络层;
第一个子层中,首先得到序列
Figure FDA00031136632400000212
的每个位置的查询向量
Figure FDA00031136632400000213
键向量
Figure FDA00031136632400000214
以及值向量
Figure FDA00031136632400000215
然后计算自注意力
Figure FDA00031136632400000216
将输入序列
Figure FDA00031136632400000217
和自注意力
Figure FDA00031136632400000218
做一个残差连接
Figure FDA0003113663240000031
第二个子层中,输出的边特征维度
Figure FDA0003113663240000032
为:
Figure FDA0003113663240000033
其中,MLP为一个包含两个全连接层的多层感知机。
8.根据权利要求1所述的方法,其特征在于,步骤S5中,查询结点vi的类别概率分布
Figure FDA0003113663240000034
计算如下:
Figure FDA0003113663240000035
其中,xi表示查询集中的样本,xj表示支持集中的样本,yj为支持样本xj的标签,
Figure FDA0003113663240000036
为最终边。
9.根据权利要求1所述的方法,其特征在于,步骤S6中,小样本分类损失Lfl
Figure FDA0003113663240000037
其中,Lce表示交叉熵损失,
Figure FDA0003113663240000038
为查询结点vi的类别概率分布,yj为支持样本xj的标签。
10.一种基于边Transformer图神经网络的小样本图像分类系统,其特征在于,包括:
采样模块,采样小样本学习任务T;
特征模块,将采样模块得到的小样本学习任务T中的每一个样本xi送入构建的嵌入网络Femb中,得到每一个样本的特征图fi
全连接模块,构建一个全连接图GT,将特征模块中每一个样本的特征图fi作为初始的结点特征
Figure FDA0003113663240000039
并根据查询样本的标签初始化边特征
Figure FDA00031136632400000310
神经网络模块,将全连接模块构建的全连接图GT输入到由L层边Transformer图神经网络构成的ETGNN中迭代进行结点特征更新,利用Transformer图神经网络的边transformer模块进行边特征更新,得到每一层的边特征
Figure FDA0003113663240000041
融合模块,对神经网络模块得到的L层的边特征
Figure FDA0003113663240000042
进行级联,然后输入到构建的边特征融合网络Ffus中,得到最终边
Figure FDA0003113663240000043
并根据最终边
Figure FDA0003113663240000044
以及支持样本的类别yj得到查询结点vi的类别概率分布
Figure FDA0003113663240000045
分类模块,根据融合模块得到的查询结点的类别概率分布
Figure FDA0003113663240000046
以及查询结点的标签yi计算小样本分类损失Lfl,端到端地训练嵌入网络Femb以及边Transformer图神经网络ETGNN,在测试时利用类别概率分布
Figure FDA0003113663240000047
对查询样本进行类别预测以实现小样本分类。
CN202110657352.4A 2021-06-11 2021-06-11 一种基于边Transformer图神经网络的小样本图像分类方法及系统 Active CN113378938B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110657352.4A CN113378938B (zh) 2021-06-11 2021-06-11 一种基于边Transformer图神经网络的小样本图像分类方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110657352.4A CN113378938B (zh) 2021-06-11 2021-06-11 一种基于边Transformer图神经网络的小样本图像分类方法及系统

Publications (2)

Publication Number Publication Date
CN113378938A true CN113378938A (zh) 2021-09-10
CN113378938B CN113378938B (zh) 2022-12-13

Family

ID=77574207

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110657352.4A Active CN113378938B (zh) 2021-06-11 2021-06-11 一种基于边Transformer图神经网络的小样本图像分类方法及系统

Country Status (1)

Country Link
CN (1) CN113378938B (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114119977A (zh) * 2021-12-01 2022-03-01 昆明理工大学 一种基于图卷积的Transformer胃癌癌变区域图像分割方法
CN114299535A (zh) * 2021-12-09 2022-04-08 河北大学 基于Transformer的特征聚合人体姿态估计方法
CN114898136A (zh) * 2022-03-14 2022-08-12 武汉理工大学 一种基于特征自适应的小样本图像分类方法
CN114299535B (zh) * 2021-12-09 2024-05-31 河北大学 基于Transformer的特征聚合人体姿态估计方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2012139813A1 (en) * 2011-04-12 2012-10-18 F. Hoffmann-La Roche Ag Connector device
CN111339883A (zh) * 2020-02-19 2020-06-26 国网浙江省电力有限公司 复杂场景下基于人工智能的变电站内异常行为识别与检测方法
CN111460097A (zh) * 2020-03-26 2020-07-28 华泰证券股份有限公司 一种基于tpn的小样本文本分类方法
CN111950596A (zh) * 2020-07-15 2020-11-17 华为技术有限公司 一种用于神经网络的训练方法以及相关设备
CN112070128A (zh) * 2020-08-24 2020-12-11 大连理工大学 一种基于深度学习的变压器故障诊断方法
CN112633403A (zh) * 2020-12-30 2021-04-09 复旦大学 一种基于小样本学习的图神经网络分类方法及装置

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2012139813A1 (en) * 2011-04-12 2012-10-18 F. Hoffmann-La Roche Ag Connector device
CN111339883A (zh) * 2020-02-19 2020-06-26 国网浙江省电力有限公司 复杂场景下基于人工智能的变电站内异常行为识别与检测方法
CN111460097A (zh) * 2020-03-26 2020-07-28 华泰证券股份有限公司 一种基于tpn的小样本文本分类方法
CN111950596A (zh) * 2020-07-15 2020-11-17 华为技术有限公司 一种用于神经网络的训练方法以及相关设备
CN112070128A (zh) * 2020-08-24 2020-12-11 大连理工大学 一种基于深度学习的变压器故障诊断方法
CN112633403A (zh) * 2020-12-30 2021-04-09 复旦大学 一种基于小样本学习的图神经网络分类方法及装置

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
XIBING ZUO ET AL.: "Graph inductive learning method for small sample classification of hyperspectral remote sensing images", 《EUROPEAN JOURNAL OF REMOTE SENSING》 *
刘颖: "基于小样本学习的图像分类技术综述", 《自动化学报》 *

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114119977A (zh) * 2021-12-01 2022-03-01 昆明理工大学 一种基于图卷积的Transformer胃癌癌变区域图像分割方法
CN114299535A (zh) * 2021-12-09 2022-04-08 河北大学 基于Transformer的特征聚合人体姿态估计方法
CN114299535B (zh) * 2021-12-09 2024-05-31 河北大学 基于Transformer的特征聚合人体姿态估计方法
CN114898136A (zh) * 2022-03-14 2022-08-12 武汉理工大学 一种基于特征自适应的小样本图像分类方法
CN114898136B (zh) * 2022-03-14 2024-04-19 武汉理工大学 一种基于特征自适应的小样本图像分类方法

Also Published As

Publication number Publication date
CN113378938B (zh) 2022-12-13

Similar Documents

Publication Publication Date Title
WO2021042828A1 (zh) 神经网络模型压缩的方法、装置、存储介质和芯片
WO2020238293A1 (zh) 图像分类方法、神经网络的训练方法及装置
CN107480261B (zh) 一种基于深度学习细粒度人脸图像快速检索方法
WO2023000574A1 (zh) 一种模型训练方法、装置、设备及可读存储介质
WO2021057056A1 (zh) 神经网络架构搜索方法、图像处理方法、装置和存储介质
JP2021524099A (ja) 異なるデータモダリティの統計モデルを統合するためのシステムおよび方法
WO2021022521A1 (zh) 数据处理的方法、训练神经网络模型的方法及设备
EP4002161A1 (en) Image retrieval method and apparatus, storage medium, and device
CN108875076B (zh) 一种基于Attention机制和卷积神经网络的快速商标图像检索方法
WO2022001805A1 (zh) 一种神经网络蒸馏方法及装置
CN110222718B (zh) 图像处理的方法及装置
WO2021051987A1 (zh) 神经网络模型训练的方法和装置
CN114398491A (zh) 一种基于知识图谱的语义分割图像实体关系推理方法
CN113378938B (zh) 一种基于边Transformer图神经网络的小样本图像分类方法及系统
CN112199532A (zh) 一种基于哈希编码和图注意力机制的零样本图像检索方法及装置
CN113378937B (zh) 一种基于自监督增强的小样本图像分类方法及系统
Chen et al. Binarized neural architecture search for efficient object recognition
CN116386899A (zh) 基于图学习的药物疾病关联关系预测方法及相关设备
CN111178196B (zh) 一种细胞分类的方法、装置及设备
CN115018039A (zh) 一种神经网络蒸馏方法、目标检测方法以及装置
CN115457332A (zh) 基于图卷积神经网络和类激活映射的图像多标签分类方法
CN116502181A (zh) 基于通道扩展与融合的循环胶囊网络多模态情感识别方法
WO2022063076A1 (zh) 对抗样本的识别方法及装置
CN114943017A (zh) 一种基于相似性零样本哈希的跨模态检索方法
Zhu et al. Local information fusion network for 3D shape classification and retrieval

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant