CN114741507B - 基于Transformer的图卷积网络的引文网络分类模型建立及分类 - Google Patents
基于Transformer的图卷积网络的引文网络分类模型建立及分类 Download PDFInfo
- Publication number
- CN114741507B CN114741507B CN202210306043.7A CN202210306043A CN114741507B CN 114741507 B CN114741507 B CN 114741507B CN 202210306043 A CN202210306043 A CN 202210306043A CN 114741507 B CN114741507 B CN 114741507B
- Authority
- CN
- China
- Prior art keywords
- network
- feature
- matrix
- node
- layer
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000013145 classification model Methods 0.000 title claims abstract description 14
- 238000005096 rolling process Methods 0.000 title claims description 18
- 238000012549 training Methods 0.000 claims abstract description 51
- 238000000034 method Methods 0.000 claims abstract description 25
- 238000012360 testing method Methods 0.000 claims abstract description 7
- 239000011159 matrix material Substances 0.000 claims description 62
- 238000013507 mapping Methods 0.000 claims description 12
- 238000010586 diagram Methods 0.000 claims description 6
- 238000004364 calculation method Methods 0.000 claims description 5
- 238000000605 extraction Methods 0.000 claims description 4
- 230000002452 interceptive effect Effects 0.000 claims description 3
- 230000008569 process Effects 0.000 claims description 3
- 238000012795 verification Methods 0.000 claims description 3
- 239000010410 layer Substances 0.000 description 47
- 241000689227 Cora <basidiomycete fungus> Species 0.000 description 9
- 230000009466 transformation Effects 0.000 description 8
- 238000002474 experimental method Methods 0.000 description 6
- 230000004913 activation Effects 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 230000006870 function Effects 0.000 description 3
- 230000004048 modification Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 206010012601 diabetes mellitus Diseases 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 238000009499 grossing Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000008520 organization Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 230000006916 protein interaction Effects 0.000 description 1
- 108090000623 proteins and genes Proteins 0.000 description 1
- 239000002356 single layer Substances 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 230000003595 spectral effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Databases & Information Systems (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Compression, Expansion, Code Conversion, And Decoders (AREA)
Abstract
本发明公开了一种基于Transformer的图卷积网络的引文网络分类模型建立方法,首先获取引文网络数据,引文网络数据包括确定节点的主体身份(论文、作者),收集结点的语料特征,确定节点的标签,确定节点间的关系,然后建立基于Transformer的图卷积网络模型,包括一个K层的简化图卷积网络模块,一个经过改造的Transformer编码器;然后利用简化图卷积网络对所有节点进行特征的卷积传播,利用Transformer编码器对训练集的所有节点的每层特征学习一个全局特征用以分类,最后利用训练好的Transformer编码器对测试结点进行分类。
Description
技术领域
本发明属于人工智能技术领域,具体涉及一种基于Transformer的图卷积网络的引文网络分类模型建立及分类方法。
背景技术
卷积神经网络(CNN)在计算机视觉中已经获得了广泛的应用,取得了非常出色的表现,尤其是图像这种欧式特征的数据,CNN中的卷积层通过学习多种不同的局部滤波器,通过滤波方式来对图像进行高层特征的提取。那么如何在图这种关系结构的数据上进行有效地特征提取就显得十分重要,类比图像上的卷积操作以及图信号处理,有了两种对图卷积的定义。一种是定义在频谱域的,例如ChebNet、GCN、SGC。另外一种是定义在空间域的,例如GarphSage、GAT。
图卷积操作的实质上是在图上进行特征平滑,根据图的结构信息,将邻近节点的特征尽可能地向同一方向平滑,所以随着图卷积网络的层数加深,所有节点的特征都会趋同,所以这就导致了目前大部分图卷积网络只能进行浅层学习,无法进行深层学习,但是如果只进行浅层学习,节点的特征不足以充分扩散到全图,无法充分利用整张图的结构信息。
发明内容
为了解决有技术中对卷积网络不能进行深层学习的技术问题,本发明的目的在于,提供一种基于Transformer的图卷积网络的引文网络分类模型建立及分类方法。
为了实现上述任务,本发明采用如下的技术解决方案:
一种基于Transformer的图卷积网络的引文网络分类模型建立方法,其特征在于,包括如下步骤:
步骤1:获取引文网络数据
引文网络数据包括确定节点的主体身份(论文、作者),收集结点的语料特征,确定节点的标签,确定节点间的关系。最终建立节点的特征矩阵X,节点的标签矩阵Y,以及节点关系图的邻接矩阵G,同时将数据分为训练集、验证集、测试集;
步骤2:建立基于Transformer的图卷积网络模型
所建立的基于Transformer的图卷积网络模型包括一个K层的简化图卷积网络模块,一个经过改造的Transformer编码器;然后利用简化图卷积网络对所有节点进行特征的卷积传播,利用Transformer编码器对训练集的所有节点的每层特征学习一个全局特征用以分类,最后利用训练好的Transformer编码器对测试结点进行分类;
其中,所述的基于Transformer的图卷积网络模型的训练模型包括如下子步骤:
步骤2.1:计算卷积传播矩阵S:
将简化图卷积网络对图上卷积的定义如式1所示:
式中,gθ'是一个卷积滤波器,X是输入图信号即节点特征,θ是可学习参数矩阵,是归一化图拉普拉斯矩阵,/>是/>的度矩阵,S是规范化图邻接矩阵;
考虑将其拓展为多层结构,且不使用非线性变换,则多层的卷积传播可表示为F=S…SXθ1…θK,然后,仅在卷积层进行特征提取,并不需要在每层进行训练学习参数,进一步假定θ1=…=θK=1,即
F=SKX (式2)
式中,SK是规范化图邻接矩阵的K次幂,X是输入节点特征矩阵计算传播矩阵其中/> 为/>的度矩阵;
步骤2.2:对特征矩阵X进行K次特征卷积传播,具体方法为:
以特征矩阵X作为输入,每层网络对输入作用一次S,并且当前层的输入为上一层网络的输出,并且为每层的输出增加一次标准化操作,将每层输出特征映射到同一分布,然后将每一层网络的输出都保存起来,为每个节点形成一个序列特征;
假定输入特征矩阵X的维度为n×d,那么最终简化卷积网络的输出F的维度为k×n×d。
步骤2.3:提取训练集特征Ftrain:
根据训练集节点的序号从上一步得到的特征矩阵F中将训练集特征全部提取出来,得到用于训练Transformer编码器的特征矩阵Ftrain,假定训练集大小为t,则Ftrain的大小为k×t×d。
步骤2.4:提取训练集标签Ytrain:
根据训练集节点的序号从标签矩阵中将训练集标签Ytrain提取出来;假定类别为c,训练集大小为t,则Ytrain大小为t×c;
步骤2.5:将Ftrain与Ytrain输入到Transformer编码器中学习全局特征,使用学习到的全局特征来进行最终节点类别的预测,具体方法是:
首先先将Ftrain经过一个MLP编码网络映射到一个低维空间,输出大小为一个可调节的超参数;
接着为每个节点的序列特征前增加一个分类头CLS TOKEN,这个分类头是一个全零特征,负责与序列中的其他特征交互学习,形成最终的全局特征;
经过MLP低维映射和增加CLS TOKEN后,节点特征变为式3所示:
Z0=【xCLSTOKEN,x1E,x2E,…,xkE】 (式3)
式中,xCLSTOKEN是在特征序列头部增加的初始化可学习全局特征,x1,x2,…,xk是上一步卷积过后每一层的输出特征,E是代表经过一层MLP进行低维编码;然后输入特征在经过多个多头注意力块(MSA)以及MLP块堆叠形成的Transformer编码器中进行学习;
具体地,输入特征会先经过一层LayerNorm,接着会应用一层多头注意力块(MSA),多头注意力的输出与LayerNorm之前的输入会经过一次残差连接,得到z′l,如式4所示:
z′l=MSA(LN(zl-1))+zl-1 (式4)
式中,LN()表示进行LayerNorm,MSA()表示作用一次多头注意力块,z′l表示当前层输出,zl-1表示当前层输入;
然后,z′l会再经过一次LayerNorm与MLP,最后MLP的输出与多头注意力块(MSA)的输出再做一次残差得到Zl,如式5所示:
zl=MLP(LN(z′l))+z′l (式5)
式中,LN()表示进行LayerNorm,MLP()表示经过一次MLP层,z′l表示当前层输入,Zl表示当前层输出;
最后对学习到的全局特征再作用一次LayerNorm后以其作为最终的分类特征,将其输入到一个MLP类别预测网络中得到预测类别;
然后计算预测类别与实际类别的交叉熵损失,在反向传播更新网络中的所有参数完成训练。
上述建立的基于Transformer的图卷积网络的引文网络分类模型的分类方法,其特征在于,包括如下步骤:
步骤一:从待测特征矩阵F中取一个待测特征序列Z=【x1,x2,…,xk】;
步骤二:初始化一个Transformer编码器,并加载已训练好的参数;
步骤三:将待测特征序列Z=【x1,x2,…,xk】输入到Transformer编码器中进行状态编码,得到全局状态特征Z0;
其中,类别预测过程包括以下步骤:
步骤a:将特征序列Z中的每个子特征输入到已训练好的MLP编码网络中得到其低维映射表示Z'=【x1E,x2E,…,xkE】;
步骤b:在经过低维编码后的低维特征序列Z'的头部增加一个全零特征CLSTOKEN,得到Transformer编码器的输入特征Zinput=【xCLSTOKEN,x1E,x2E,…,xkE】,并使用该特征学习一个全局分类特征;
步骤c:将输入特征Zinput=【xCLSTOKEN,x1E,x2E,…,xkE】输入到以训练并加载好的Transformer编码器中进行编码计算,得到Zoutput=【Z0,Z1,Z2…,Zk】,Z0为编码得到的全局特征;
步骤四:将上一步得到的全局状态特征Z0输入到已训练好的类别预测MLP网络中进行类别预测,得到最终分类结果。
本发明建立的基于Transformer的图卷积网络的引文网络分类模型及其分类方法,带来的技术创新在于:
1、在进行提取节点特征进行分类时,没有像目前其他技术一样只采用最后一次卷积传播后的特征作为最终分类特征,而是通过将每层卷积层的特征都提取出来,将其组成一个特征序列,通过对该特征序列进一步学习一个全局特征,最后应用该全局特征进行分类。因为使用到了每一层的特征信息,并未只使用最后一层已经过度平滑了的特征,因此,所建立的基于Transformer的图卷积网络的引文网络分类模型在加深时也不会出现性能下降。
2、建立的基于Transformer的图卷积网络采用了一种简化的图卷积定义方式,去掉了不同卷积层之间的非线性激活函数。因为通过实验发现,添加了非线性激活函数不但不能提高网络的性能,而且极大增加了算法的时间复杂度与空间复杂度,所以与其他方法相比在时间复杂度和空间复杂度上都有较大优势。
3、由于使用了全部特征进行全局特征的学习,避免了最后一层网络特征过度平滑的问题,因此可以进行深层学习,相较于其他技术的浅层学习,本申请的分类方法充分利用了整张图的结构信息,从而其性能与其他技术相比更好。
附图说明
图1是引文网络与结点分类任务的示意图;
图2是Transformer的结构图。
图3是基于Transformer的简化图卷积网络模型的结构图。
以下结合附图和实施例对本发明作进一步地详细说明。
具体实施方式
首先对本发明中出现的技术词语进行解释说明:
引文网络:是一种由论文、作者及其引用关系组成的数据集。这些论文/作者(节点)通过引用关系(边)相互连接组成,而这些论文/作者都有一个对应的类别标签,是一种图结构的数据集,也就是节点的组织方式是多对多的。一般引文网络的组织方式为两个部分:特征和图,也就是其连接关系组织为一张图,通常采用邻接矩阵或字典来进行存储,在实际使用时,如果是字典存储的,一般需要再进一步加工成邻接矩阵形式,另一部分就是节点的特征,一般存储为为一个一维向量,向量每个维度对应了字典中的某个单词,也就是节点本身是由一段文本进行描述,将文本与字典对应起来,一般组成一个one-hot向量进行存储。而边是没有特征的,这与知识图谱(多关系图)是不一样的,知识图谱中的边是有信息的,而引文网络中的边只是一种相互引用关系,并无实质特征。
语料特征:一段描述引文网络中节点的特征的文字,这段文字编码成一个只包含0和1的向量来表示,向量的长度为语料字典的大小,每个维度上的值指示该节点的特征描述文字中是否包含指向字典某个单词,如果包含为1,不包含为0。
节点分类:节点分类任务是根据图信息和节点自身特征信息,通过训练一个分类器对图中每个未标记节点预测出一个特定类别,例如在蛋白质相互作用网络中,通过给定的图数据和节点数据,需要给出每个节点可以分配几个基因本体类型。在引文网络中,给定作者节点或者文章节点,以及作者或文章的相互引用关系网,需要去对每篇文章或者每个作者节点去预测一个文章或作者类别或者他们的研究主题。
本实施例给出一种基于基于Transformer的图卷积网络的引文网络分类模型建立方法,包括如下步骤:
步骤1:获取引文网络数据,包括确定节点的主体身份(论文、作者),收集结点的语料特征,确定节点的标签,确定节点间的关系。最终建立节点的特征矩阵X,节点的标签矩阵Y,以及节点关系图的邻接矩阵G,同时将数据分为训练集、验证集、测试集;
步骤2:建立基于Transformer的图卷积网络模型,所述的基于Transformer的图卷积网络模型包括两个模块,一个K层的简化图卷积网络模块,一个经过改造的Transformer编码器。然后利用简化图卷积网络对所有节点进行特征的卷积传播,利用Transformer编码器对训练集的所有节点的每层特征学习一个全局特征用以分类,最后利用训练好的Transformer编码器对测试结点进行分类;
其中,所述的基于Transformer的图卷积网络模型的训练模型包括如下子步骤:
步骤2.1:计算卷积传播矩阵S;
将简化图卷积网络对图上卷积的定义为式1:
上式中,gθ'是一个卷积滤波器,X是输入图信号即节点特征,θ是可学习参数矩阵,是归一化图拉普拉斯矩阵,/>是/>的度矩阵,S是规范化图邻接矩阵;
考虑将其拓展为多层结构,且不使用非线性变换,则多层的卷积传播可表示为F=S...SXθ1...θK,然后,我们只是在卷积层进行特征提取,并不需要在每层进行训练学习参数,进一步假定θ1=...=θK=1,即有:
F=SKX 式2
式中,SK是规范化图邻接矩阵的K次幂,X是输入节点特征矩阵计算传播矩阵其中,/> 为/>的度矩阵;S是规范化图邻接矩阵,是归一化拉普拉斯矩阵。
步骤2.2:对特征矩阵X进行K次特征卷积传播:
以特征矩阵X作为输入,每层网络对输入作用一次S,并且当前层的输入为上一层网络的输出,并且为每层的输出增加一次标准化操作,将每层输出特征映射到同一分布,然后将每一层网络的输出都保存起来,为每个节点形成一个序列特征。假定输入特征矩阵X的维度为n×d,那么最终简化卷积网络的输出F的维度为k×n×d。
步骤2.3:提取训练集特征矩阵Ftrain:
根据训练集节点的序号从上一步得到的特征矩阵中将训练集特征全部提取出来,得到用于训练Transformer编码器的特征矩阵Ftrain,假定训练集大小为t,则Ftrain的大小为k×t×d。
步骤2.4:提取训练集标签Ytrain;
根据训练集节点的序号从标签矩阵中将训练集标签Ytrain提取出来。假定类别为c,训练集大小为t,则练集标签Ytrain大小为t×c。
步骤2.5:将Ftrain与Ytrain输入到Transformer编码器中学习全局特征,使用学习到的全局特征来进行最终节点类别的预测,具体方法是:
首先将Ftrain经过一个MLP编码网络映射到一个低维空间,输出大小为一个可调节的超参数。
接着为每个节点的序列特征前增加一个分类头CLS TOKEN,这个分类头是一个全零特征,负责与序列中的其他特征交互学习,形成最终的全局特征。但不像标准的Transformer那样还需要进行位置编码,在这一步去掉了位置编码,因为经过实验发现位置编码时不必要的,如果增加的位置编码反而会带来性能下降。
经过MLP低维映射和增加CLS TOKEN后,节点特征变为式3所示:
Z0=【xCLSTOKEN,x1E,x2E,…,xkE】 (式3)
然后输入特征在经过多个多头注意力块(MSA)以及MLP块堆叠形成的Transformer编码器中进行学习。
具体地,输入特征会先经过一层LayerNorm,接着会应用一层多头注意力块(MSA),多头注意力的输出与LayerNorm之前的输入会经过一次残差连接,得到z′l,如式4:
z′l=MSA(LN(zl-1))+zl-1 (式4)
然后z′l会再经过一次LayerNorm与MLP,最后MLP的输出与多头注意力的输出在做一次残差得到Zl,如式5所示:
zl=MLP(LN(z′l))+z′l (式5)
式中,LN()表示进行LayerNorm,MLP()表示经过一次MLP层,z′l表示当前层输入,Zl表示当前层输出;
最后对学习到的全局特征再作用一次LayerNorm后以其作为最终的分类特征,将其输入到一个MLP类别预测网络中得到预测类别。
然后计算预测类别与实际类别的交叉熵损失,在反向传播更新网络中的所有参数完成训练。
上述建立的基于Transformer的图卷积网络的引文网络分类模型的分类方法,包括如下步骤:
步骤一:从待测特征矩阵F中取一个待测特征序列Z=【x1,x2,…,xk】;
步骤二:初始化一个Transformer编码器,并加载已训练好的参数;
步骤三:将待测特征序列Z=【x1,x2,…,xk】输入到Transformer编码器中进行状态编码,得到全局状态特征Z0。
其中,类别预测过程包括以下步骤:
步骤a:将特征序列Z中的每个子特征输入到已训练好的MLP编码网络中得到其低维映射表示Z'=【x1E,x2E,…,xkE】;
步骤b:在经过低维编码后的低维特征序列Z'的头部增加一个全零特征CLSTOKEN,得到Transformer编码器的输入特征Zinput=【xCLSTOKEN,x1E,x2E,…,xkE】,其目的在于使用该特征学习一个全局分类特征。
步骤c:将输入特征Zinput=【xCLSTOKEN,x1E,x2E,…,xkE】输入到以训练并加载好的Transformer编码器中进行编码计算得到Zoutput=【Z0,Z1,Z2…,Zk】,Z0为编码得到的全局状态特征。
步骤四:将上一步得到的全局状态特征Z0输入到已训练好的类别预测MLP网络中进行类别预测,得到最终分类结果。
在上述实施例的基础上,为了验证上述实施例给出基于Transformer的图卷积网络的引文网络分类模型建立及其分类方法,发明人给出了以下的实验例。
本实验例中,采用的引文网络数据集分别是cora、citeseer以及pubmed这三个数据集。
其中,Cora数据集包含了一共七类,2708篇机器学习出版物,每篇论文由一个长度为1433单词热向量表示。Citeseer数据集由六类3327篇科学论文组成,每篇论文由一个长度为3703的单词热向量表示。Pubmed数据集由3类19717篇糖尿病相关出版物构成,每篇论文由一个词频-逆文档频率(TF-IDF)向量表示。这几个数据集遵循了主流的半监督划分。数据集的划分数量情况见下表1。
表1:数据集信息表
在实验例中,Transformer编码器去掉了其位置编码,为了证明所进行的修改(即Transformer编码器去掉了其位置编码)是有效的,在同样的网络参数下,发明人测试了这三个数据集上分别增加位置编码和不增加位置编码的准确率,如表2所示,去除了位置编码后,三个数据集上的准确率均提升了5-10个百分点,从而说明其修改是有效的。
表2:增添位置编码性能对比表
cora | citeseer | pubmed | |
增加位置编码 | 0.773 | 0.633 | 0.751 |
去除位置编码 | 0.827 | 0.718 | 0.800 |
本实施例中,将特征卷积网络部分去掉了非线性变换,为了说明所进行的去掉非线性变换是有用的,发明人进行了非线性变换进行实验。如表3所示,发明人尝试了网络深度为1-15,增加非线性变换与未增加非线性变换两种情况下的实验,这里增加的非线性变换为Relu激活函数,在Cora数据集上,网络层数为15时,如果增加非线性变换,网络的准确率只有0.377,如果不增加则准确率为0.809,而pubmed与citeseer数据集上,当网络层数为15时,由于内存溢出已经不能正常在12G显存的单卡上进行训练了,而即使在单层网络情况下,例如在1层网络情况下,Cora数据集上比不增加降低了0.262,pubmed数据集上比不增加降低了0.07,Citeseer数据集上也比不增加降低了0.481。
表3:增加非线性变换与去除后性能对比
为了说明本实验例所涉及算法相较于其他算法的一些优越性,发明人在Cora、Citeseer、Pubmed等三个数据集上进行了对比实验,对比了GCN、GAT、FastGCN、SGC等几个主流算法与本实验例的算法的性能以及训练时间。
表4给出了其他算法与本实验例给出算法的准确率指标对比,表5给出了其他算法与本实施例算法的训练时间对比。在Cora数据集上GAT的准确率最高达到了0.830,本实验例提出的算法准确率只比其低了0.003,但是本实施例的训练时间为0.45s,远低于GAT的63.1s,在Citeseer数据集上,也是GAT的性能最好,达到了0.725,本实验例的准确率为0.718,只比其低了0.007,但是本实验例的训练时间为1.2s,也远远低于GAT的118.1s,最后在Pubmed数据集上,本实验中,所采用的算法性能达到了最佳,准确率为0.800,并且训练时间也十分少,仅为1.05s。综上,本实验例给出的算法虽然在某些数据集上的性能不是最好,但是其算法的消耗时间是远低于性能最好算法的,并且性能与最好算法的性能差距也控制在了0.01以内。
表4:准确率对比
cora | citeseer | pubmed | |
GCN | 0.815 | 0.703 | 0.790 |
GAT | 0.830 | 0.725 | 0.790 |
FastGCN | 0.798 | 0.686 | 0.774 |
SGC | 0.810 | 0.719 | 0.789 |
OUR | 0.827 | 0.718 | 0.800 |
表5:训练时间对比
cora | citeseer | pubmed | |
GCN | 0.49 | 0.59 | 8.31 |
GAT | 63.1 | 118.1 | 121.74 |
FastGCN | 2.47 | 3.96 | 1.77 |
SGC | 0.13 | 0.14 | 0.29 |
OUR | 0.45 | 1.20 | 1.05 |
Claims (2)
1.一种基于Transformer的图卷积网络的引文网络分类模型建立方法,其特征在于,包括如下步骤:
步骤1:获取引文网络数据
引文网络数据包括确定节点的主体身份,收集结点的语料特征,确定节点的标签,确定节点间的关系;最终建立节点的特征矩阵X,节点的标签矩阵Y,以及节点关系图的邻接矩阵G,同时将数据分为训练集、验证集、测试集;
步骤2:建立基于Transformer的图卷积网络模型
所建立的基于Transformer的图卷积网络模型,包括一个K层的简化图卷积网络模块,一个经过改造的Transformer编码器;然后利用简化图卷积网络对所有节点进行特征的卷积传播,利用Transformer编码器对训练集的所有节点的每层特征学习一个全局特征用以分类,最后利用训练好的Transformer编码器对测试结点进行分类;
其中,所述的基于Transformer的图卷积网络模型的训练模型包括如下子步骤:
步骤2.1:计算卷积传播矩阵S:
将简化图卷积网络对图上卷积的定义如式1所示:
式中,gθ'是一个卷积滤波器,X是输入图信号即节点特征,θ是可学习参数矩阵,是归一化图拉普拉斯矩阵,/>是/>的度矩阵,S是规范化图邻接矩阵;
考虑将其拓展为多层结构,且不使用非线性变换,则多层的卷积传播可表示为F=S…SXθ1…θK,然后,仅在卷积层进行特征提取,并不需要在每层进行训练学习参数,进一步假定θ1=…=θK=1,即有:
F=SKX (式2)
式中,SK是规范化图邻接矩阵的K次幂,X是输入节点特征矩阵;
计算传播矩阵其中,/> 为/>的度矩阵;S是规范化图邻接矩阵,/>是归一化拉普拉斯矩阵;
步骤2.2:对特征矩阵X进行K次特征卷积传播,具体方法为:
以特征矩阵X作为输入,每层网络对输入作用一次S,并且当前层的输入为上一层网络的输出,并且为每层的输出增加一次标准化操作,将每层输出特征映射到同一分布,然后将每一层网络的输出都保存起来,为每个节点形成一个序列特征;
假定输入特征矩阵X的维度为n×d,那么最终简化卷积网络的输出F的维度为k×n×d;
步骤2.3:提取训练集特征矩阵Ftrain:
根据训练集节点的序号从上一步得到的特征矩阵F中将训练集特征全部提取出来,得到用于训练Transformer编码器的特征矩阵Ftrain,假定训练集大小为t,则Ftrain的大小为k×t×d;
步骤2.4:提取训练集标签Ytrain:
根据训练集节点的序号从标签矩阵中将训练集标签Ytrain提取出来;假定类别为c,训练集大小为t,则训练集标签Ytrain大小为t×c;
步骤2.5:将Ftrain与Ytrain输入到Transformer编码器中学习全局特征,使用学习到的全局特征来进行最终节点类别的预测,具体方法是:
首先将Ftrain经过一个MLP编码网络映射到一个低维空间,输出大小为一个可调节的超参数;
接着为每个节点的序列特征前增加一个分类头CLS TOKEN,这个分类头是一个全零特征,负责与序列中的其他特征交互学习,形成最终的全局特征;
经过MLP低维映射和增加CLS TOKEN后,节点特征变为式3所示:
Z0=【xCLSTOKEN,x1E,x2E,…,xkE】 (式3)
式中,xCLSTOKEN是在特征序列头部增加的初始化可学习全局特征,x1,x2,…,xk是上一步卷积过后每一层的输出特征,E是代表经过一层MLP进行低维编码;
然后输入特征在经过多个多头注意力块(MSA)以及MLP块堆叠形成的Transformer编码器中进行学习,具体方法为:
输入特征会先经过一层LayerNorm,接着会应用一层多头注意力块(MSA),多头注意力的输出与LayerNorm之前的输入会经过一次残差连接,得到z′l,如式4所示:
z′l=MSA(LN(zl-1))+zl-1 (式4)
式中,LN()表示进行LayerNorm,MSA()表示作用一次多头注意力块,z′l表示当前层输出,Zl-1表示当前层输入;
然后,z′l会再经过一次LayerNorm与MLP,最后MLP的输出与多头注意力块(MSA)的输出再做一次残差得到Zl,如式5所示:
zl=MLP(LN(z′l))+z′l (式5)
式中,LN()表示进行LayerNorm,MLP()表示经过一次MLP层,z′l表示当前层输入,Zl表示当前层输出;
最后对学习到的全局特征再作用一次LayerNorm后以其作为最终的分类特征,将其输入到一个MLP类别预测网络中得到预测类别;
然后计算预测类别与实际类别的交叉熵损失,在反向传播更新网络中的所有参数完成训练。
2.权利要求1所建立的基于Transformer的图卷积网络的引文网络分类模型的分类方法,其特征在于,包括如下步骤:
步骤一:从待测特征矩阵F中取一个待测特征序列Z=【x1,x2,…,xk】;
步骤二:初始化一个Transformer编码器,并加载已训练好的参数;
步骤三:将待测特征序列Z=【x1,x2,…,xk】输入到Transformer编码器中进行状态编码,得到全局状态特征Z0;
其中,类别预测过程包括以下步骤:
步骤a:将特征序列Z中的每个子特征输入到已训练好的MLP编码网络中得到其低维映射表示Z'=【x1E,x2E,…,xkE】;
步骤b:在经过低维编码后的低维特征序列Z'的头部增加一个全零特征CLS TOKEN,得到Transformer编码器的输入特征Zinput=【xCLSTOKEN,x1E,x2E,…,xkE】,并使用该特征学习一个全局分类特征;
步骤c:将输入特征Zinput=【xCLSTOKEN,x1E,x2E,…,xkE】输入到以训练并加载好的Transformer编码器中进行编码计算,得到Zoutput=【Z0,Z1,Z2…,Zk】,Z0为编码得到的全局状态特征;
步骤四:将上一步得到的全局状态特征Z0输入到已训练好的类别预测MLP网络中进行类别预测,得到最终分类结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210306043.7A CN114741507B (zh) | 2022-03-25 | 2022-03-25 | 基于Transformer的图卷积网络的引文网络分类模型建立及分类 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210306043.7A CN114741507B (zh) | 2022-03-25 | 2022-03-25 | 基于Transformer的图卷积网络的引文网络分类模型建立及分类 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114741507A CN114741507A (zh) | 2022-07-12 |
CN114741507B true CN114741507B (zh) | 2024-02-13 |
Family
ID=82276441
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210306043.7A Active CN114741507B (zh) | 2022-03-25 | 2022-03-25 | 基于Transformer的图卷积网络的引文网络分类模型建立及分类 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114741507B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116821452B (zh) * | 2023-08-28 | 2023-11-14 | 南京邮电大学 | 一种图节点分类模型训练方法、图节点分类方法 |
CN117315194B (zh) * | 2023-09-27 | 2024-05-28 | 南京航空航天大学 | 面向大型飞机外形的三角网格表征学习方法 |
CN118233035B (zh) * | 2024-05-27 | 2024-08-06 | 烟台大学 | 一种基于图卷积倒置Transformer的多频带频谱预测方法及系统 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109977223A (zh) * | 2019-03-06 | 2019-07-05 | 中南大学 | 一种融合胶囊机制的图卷积网络对论文进行分类的方法 |
JP2020205029A (ja) * | 2019-06-17 | 2020-12-24 | 大連海事大学 | ブロードラーニングシステムに基づく高速ネットワーク表現学習の方法 |
CN114119977A (zh) * | 2021-12-01 | 2022-03-01 | 昆明理工大学 | 一种基于图卷积的Transformer胃癌癌变区域图像分割方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11868730B2 (en) * | 2020-09-23 | 2024-01-09 | Jingdong Digits Technology Holding Co., Ltd. | Method and system for aspect-level sentiment classification by graph diffusion transformer |
-
2022
- 2022-03-25 CN CN202210306043.7A patent/CN114741507B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109977223A (zh) * | 2019-03-06 | 2019-07-05 | 中南大学 | 一种融合胶囊机制的图卷积网络对论文进行分类的方法 |
JP2020205029A (ja) * | 2019-06-17 | 2020-12-24 | 大連海事大学 | ブロードラーニングシステムに基づく高速ネットワーク表現学習の方法 |
CN114119977A (zh) * | 2021-12-01 | 2022-03-01 | 昆明理工大学 | 一种基于图卷积的Transformer胃癌癌变区域图像分割方法 |
Non-Patent Citations (2)
Title |
---|
基于图卷积网络和自编码器的半监督网络表示学习模型;王杰;张曦煌;;模式识别与人工智能(第04期);全文 * |
通过细粒度的语义特征与Transformer丰富图像描述;王俊豪;罗轶凤;;华东师范大学学报(自然科学版)(第05期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN114741507A (zh) | 2022-07-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CA3085033C (en) | Methods and systems for multi-label classification of text data | |
Yang et al. | Deep transfer learning for military object recognition under small training set condition | |
CN114741507B (zh) | 基于Transformer的图卷积网络的引文网络分类模型建立及分类 | |
Mariet et al. | Diversity networks: Neural network compression using determinantal point processes | |
Hassan et al. | Efficient deep learning model for text classification based on recurrent and convolutional layers | |
CN110969020B (zh) | 基于cnn和注意力机制的中文命名实体识别方法、系统及介质 | |
Chen et al. | Big data deep learning: challenges and perspectives | |
CN108229582A (zh) | 一种面向医学领域的多任务命名实体识别对抗训练方法 | |
CN106650813A (zh) | 一种基于深度残差网络和lstm的图像理解方法 | |
Rae et al. | Fast parametric learning with activation memorization | |
Wan et al. | A hybrid neural network-latent topic model | |
CN110046249A (zh) | 胶囊网络的训练方法、分类方法、系统、设备及存储介质 | |
WO2017193685A1 (zh) | 社交网络中数据的处理方法和装置 | |
CN109614611B (zh) | 一种融合生成非对抗网络与卷积神经网络的情感分析方法 | |
CN107491782A (zh) | 利用语义空间信息的针对少量训练数据的图像分类方法 | |
Grzegorczyk | Vector representations of text data in deep learning | |
Korshunova et al. | Discriminative topic modeling with logistic LDA | |
Glauner | Comparison of training methods for deep neural networks | |
CN114048729A (zh) | 医学文献评价方法、电子设备、存储介质和程序产品 | |
Aich et al. | Convolutional neural network-based model for web-based text classification. | |
Khayyat et al. | A deep learning based prediction of arabic manuscripts handwriting style. | |
Liebenwein et al. | Sparse flows: Pruning continuous-depth models | |
Zhang et al. | Cosine: compressive network embedding on large-scale information networks | |
CN113590748B (zh) | 基于迭代网络组合的情感分类持续学习方法及存储介质 | |
Vega et al. | Dynamic neural networks for text classification |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |