CN116821776A - 一种基于图自注意力机制的异质图网络节点分类方法 - Google Patents
一种基于图自注意力机制的异质图网络节点分类方法 Download PDFInfo
- Publication number
- CN116821776A CN116821776A CN202311099604.1A CN202311099604A CN116821776A CN 116821776 A CN116821776 A CN 116821776A CN 202311099604 A CN202311099604 A CN 202311099604A CN 116821776 A CN116821776 A CN 116821776A
- Authority
- CN
- China
- Prior art keywords
- node
- self
- attention
- graph
- heterogeneous graph
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 55
- 230000007246 mechanism Effects 0.000 title claims abstract description 34
- 239000011159 matrix material Substances 0.000 claims abstract description 57
- 238000012549 training Methods 0.000 claims abstract description 26
- 238000012360 testing method Methods 0.000 claims abstract description 25
- 238000012795 verification Methods 0.000 claims abstract description 18
- 230000006870 function Effects 0.000 claims description 21
- 230000008569 process Effects 0.000 claims description 17
- 239000013598 vector Substances 0.000 claims description 12
- 238000010606 normalization Methods 0.000 claims description 11
- 238000004364 calculation method Methods 0.000 claims description 7
- 238000006243 chemical reaction Methods 0.000 claims description 7
- 230000002779 inactivation Effects 0.000 claims description 7
- 230000004913 activation Effects 0.000 claims description 5
- 230000004931 aggregating effect Effects 0.000 claims description 3
- 238000003491 array Methods 0.000 claims description 3
- 230000017105 transposition Effects 0.000 claims description 3
- 238000010586 diagram Methods 0.000 description 5
- 238000007796 conventional method Methods 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 2
- 238000003745 diagnosis Methods 0.000 description 2
- 201000010099 disease Diseases 0.000 description 2
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 2
- 238000012502 risk assessment Methods 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 150000001875 compounds Chemical class 0.000 description 1
- 238000007418 data mining Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 229940079593 drug Drugs 0.000 description 1
- 239000003814 drug Substances 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000000547 structure data Methods 0.000 description 1
- 239000000126 substance Chemical group 0.000 description 1
- 208000024891 symptom Diseases 0.000 description 1
Landscapes
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明提供了交通流预测技术领域的一种基于图自注意力机制的异质图网络节点分类方法,包括:步骤S1、获取大量的交通异质图网络的数据集并划分为训练集、验证集和测试集,从训练集、验证集和测试集中提取交通异质图网络的节点特征矩阵和邻接矩阵集合;步骤S2、创建一异质图自注意力网络模型;步骤S3、利用训练集对异质图自注意力网络模型进行训练,利用验证集对训练后的异质图自注意力网络模型进行验证;步骤S4、利用测试集对异质图自注意力网络模型进行测试,并不断优化超参数;步骤S5、利用异质图自注意力网络模型进行交通异质图网络的节点分类,进而进行交通流预测。本发明的优点在于:极大的提升了交通流预测的准确率。
Description
技术领域
本发明涉及交通流预测技术领域,特别指一种基于图自注意力机制的异质图网络节点分类方法。
背景技术
图神经网络(Graph Neural Networks,简称GNNs)是一类用于图数据挖掘的深度学习方法,被广泛应用于众多领域并且取得了很好的成果。在异质图网络上进行节点分类是GNNs的一项重要任务,异质图网络指由不同类型的节点和边(关系)组成的图网络,存在于许多现实世界的场景中,如社交网络中的用户和用户之间的多种关系,化合物分子中不同类型的原子和化学键等。异质图网络节点分类的目标是将所有节点分类到对应的类别中,从而更好地理解和学习异质图网络的结构和特征。
异质图网络的节点分类可应用在不同领域,如金融风险评估、推荐系统、医疗诊断等。在金融风险评估领域,可以使用异质图网络表示用户、资产和交易等信息,并通过节点分类来评价客户的信用等级和风险水平;在推荐系统领域,可以使用异质图网络表示用户、商品和用户商品交互信息,并通过节点分类来得到用户的兴趣和购买行为;在医疗诊断领域,可以使用异质图网络表示疾病、症状、药物等信息,并通过节点分类来预测疾病的类型以及严重程度。异质图网络的节点分类具有现实意义,可以更好地帮助我们理解及分析复杂的图结构数据,从而在多个领域实现更准确的预测。
由于异质图网络中的节点和边具有不同的类型,因此在进行节点分类时,不仅需要考虑节点的特征,同时需要考虑节点间复杂的异构信息。例如,在社交网络中,用户节点可能具有不同的类型以及节点间存在不同的联系等异构信息,而节点特征可能具有如年龄、性别、职业、爱好等属性,这些属性和异构信息可以作为图的特征输入到GNNs中,以帮助提高分类的准确性。
异质图网络节点分类的一大难点是如何使用其丰富的异构信息提高分类的准确率,传统方法通常是使用异质图网络上的元路径来定义不同类型节点之间的关系,并利用元路径推导出节点之间的相似性,然后使用GNNs对节点进行编码和分类,但传统方法无法捕捉到异质图网络中节点的高阶语义信息,无法学习到元路径以外的一些节点特征表示信息,导致使用传统方法进行交通流预测时,预测(节点分类)的准确率不尽如人意。
因此,如何提供一种基于图自注意力机制的异质图网络节点分类方法,实现提升交通流预测的准确率,成为一个亟待解决的技术问题。
发明内容
本发明要解决的技术问题,在于提供一种基于图自注意力机制的异质图网络节点分类方法,实现提升交通流预测的准确率。
本发明是这样实现的:一种基于图自注意力机制的异质图网络节点分类方法,包括如下步骤:
步骤S1、获取大量的交通异质图网络的数据集,按预设比例将所述数据集划分为训练集、验证集和测试集,从所述训练集、验证集和测试集中分别提取交通异质图网络的节点特征矩阵和邻接矩阵集合;
步骤S2、基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型;
步骤S3、利用所述训练集对异质图自注意力网络模型进行训练,利用所述验证集对训练后的异质图自注意力网络模型进行验证;
步骤S4、利用所述测试集对验证后的异质图自注意力网络模型进行测试,并不断优化所述异质图自注意力网络模型的超参数;
步骤S5、利用测试后的所述异质图自注意力网络模型进行交通异质图网络的节点分类,进而进行交通流预测。
进一步的,所述步骤S1中,所述预设比例为2:1:7。
进一步的,所述步骤S1中,所述节点特征矩阵为:
X∈RN×d;
所述邻接矩阵集合为不同类型边的邻接矩阵集合,公式为:
;
其中,X表示节点特征;R表示实数;N表示节点数量;d表示节点特征的输入维度;A表示邻接矩阵;K表示异质图的边的类型数;k表示邻接矩阵编号。
进一步的,所述步骤S2中,所述全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示;
所述全局自注意力模块的学习过程为:
S211、将所述节点特征矩阵X分别通过三个可学习的矩阵WQ、WK、WV投影为Q、K、V:
Q=XWQ,K=XWK,V=XWV;
其中,WQ∈Rd×dk;WK∈Rd×dk;WV∈Rd×dv;dk=dv=d;
S212、对所述Q、K、V应用归一化的点乘注意力机制计算自注意力矩阵SAttn:
;
其中,softmax()表示归一化指数函数;T表示矩阵转置操作;
S213、并行执行多次归一化的点乘注意力机制,把计算得到的各所述自注意力矩阵SAttn相加取均值,得到节点嵌入XMHead:
;
其中,XMHead∈RN×d,表示经过多头注意力机制学习得到的节点嵌入;Head表示多头注意力机制的头数;W0∈Rd×dv;
S214、对所述节点嵌入XMHead与Q做残差连接后进行归一化,得到节点嵌入XN1:
XN1=Norm(Q+XMHead(Q,K,V));
其中,XN1∈RN×d,表示经过第一次归一化后得到的节点嵌入;Norm()表示归一化函数;
S215、将所述节点嵌入XN1输送到由两层线性连接层组成的前馈网络,并在两个所述线性连接层之间使用激活函数Relu来增加全局自注意力模块的非线性,得到节点嵌入XFFN:
XFFN=Linear(Relu(Linear(XN1)));
其中,XFFN∈RN×d,表示经过前馈网络后得到的节点嵌入;Linear()表示线性连接层;
S216、对所述节点嵌入XFFN与XN1做残差连接后进行归一化,得到节点嵌入XN2:
XN2=Norm(XN1+XFFN);
S217、对所述节点特征矩阵X和节点嵌入XN2进行拼接,得到节点特征表示XG:
XG=X‖XN2;
其中,XG∈RN×2d;‖表示拼接操作。
进一步的,所述步骤S2中,所述图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示;
所述图自注意力模块的学习过程为:
S221、把不同类型边所构成的邻接矩阵A聚合在一起,得到新的邻接矩阵AC:
AC=Conv(A;WC)=AWC;
其中,AC∈RN×N;Conv()表示卷积函数;WC∈RK×1×1,表示可学习的参数矩阵;
S222、在所述邻接矩阵AC、节点特征表示XG的基础上,利用图卷积层学习交通异质图网络的节点以及其一阶邻居的特征信息,得到节点嵌入XC:
XC=Relu(GraphConv(XG;AC))=Relu(ACXGW);
其中,XC∈RN×dout,表示经过图卷积层学习得到的节点嵌入;dout表示输出的嵌入维度;GraphConv()表示图卷积操作;W∈R2d×dout,表示图卷积的权重矩阵;
S223、给定节点嵌入XC=[x1,x2…xN]T∈RN×dout,xN∈Rdout,表示节点N的特征表示;对于存在连接边的节点i和节点j,使用可学习参数Wq、Wk、bq、bk,将节点i的特征xi和节点j的特征xj分别转化为qi和kj:
qi=Wqxi+bq;
kj=Wkxj+bk;
其中,qi∈Rdout,kj∈Rdout,均为向量;
S224、将所述邻接矩阵AC通过可学习参数We、be转换为边缘特征eij,将所述边缘特征eij加入向量kj,得到向量kj’:
eij=WeAij+be;
kj’=kj+eij;
其中,Aij为邻接矩阵AC中的元素值,表示节点i和节点j之间存在相连的边;
S225、计算从节点j到节点i的每一条边的归一化点乘注意力αij:
;
;
其中,exp()表示以自然常数e为底的指数函数;N(i)表示节点i基于邻接矩阵AC的一阶邻居节点;
S226、通过可学习参数Wv、bv将节点j的特征xj转换为vj:
vj=Wvxj+bv;
其中,vj∈Rdout;
S227、基于所述vj、αij、eij计算多头注意力,得到节点嵌入zi:
;
S228、对所述节点嵌入zi引入门控单元Gate以及残差连接,得到节点嵌入:
ri=Wrxi+br;
di=zi‖ri‖(zi-ri);
;
;
其中,Wr、br、Wg均为可学习参数,且Wg∈R3dout;i表示节点编号;T表示转置操作;‖表示拼接操作;d表示拼接操作后得到的矢量;
S229、对所述节点嵌入进行归一化,得到节点嵌入Zi:
;
其中,Zi∈Rdout;
S230、重复两次S221-S229的学习过程,在经过所述图自注意力模块的学习后,获得所有节点最终的节点嵌入Z,Z∈RN×dout。
进一步的,所述步骤S2中,所述输出模块用于预测节点类别;
所述输出模块的计算过程为:
将所述节点嵌入Z输入两个全连接层和softmax函数得到预测的节点类别P:
P=softmax(Linear(Linear(Z)));
其中,P∈R1×n,n表示节点类别数。
进一步的,所述步骤S4中,所述超参数至少包括随机失活率、权值衰减率以及学习率。
本发明的优点在于:
通过获取大量的交通异质图网络的数据集并划分为训练集、验证集和测试集,从训练集、验证集和测试集中分别提取异质图网络的节点特征矩阵和邻接矩阵集合;基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型,利用训练集对异质图自注意力网络模型进行训练,利用验证集对训练后的异质图自注意力网络模型进行验证,利用测试集对验证后的异质图自注意力网络模型进行测试,并不断优化异质图自注意力网络模型的超参数,最后利用测试后的异质图自注意力网络模型进行交通流预测;由于全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示,图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示,在整个学习过程中不需要使用元路径,并能够更好学习交通异质图网络丰富的特征信息和高阶语义信息,具有更强大的异质图网络的节点特征学习能力,进而极大的提升了交通流预测的准确率。
附图说明
下面参照附图结合实施例对本发明作进一步的说明。
图1是本发明一种基于图自注意力机制的异质图网络节点分类方法的流程图。
图2是本发明异质图自注意力网络模型的结构示意图。
图3是本发明全局自注意力模块中多头注意力的结构示意图。
图4是本发明图转换注意力层的结构示意图。
具体实施方式
本申请实施例中的技术方案,总体思路如下:创建由全局自注意力模块、图自注意力模块以及输出模块组成的异质图自注意力网络模型,全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示,图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示,在整个学习过程中不需要使用元路径,并能够更好学习交通异质图网络丰富的特征信息和高阶语义信息,具有更强大的异质图网络的节点特征学习能力,以提升交通流预测的准确率。
请参照图1至图4所示,本发明一种基于图自注意力机制的异质图网络节点分类方法的较佳实施例,包括如下步骤:
步骤S1、获取大量的交通异质图网络的数据集,按预设比例将所述数据集划分为训练集、验证集和测试集,从所述训练集、验证集和测试集中分别提取交通异质图网络的节点特征矩阵和邻接矩阵集合;具体实施时,所述数据集可选取交通异质图网络的公共基准数据集ACM、DBLP和IMDB;
步骤S2、基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型;
步骤S3、利用所述训练集对异质图自注意力网络模型进行训练,利用所述验证集对训练后的异质图自注意力网络模型进行验证;
训练过程中,使用交叉熵损失作为损失函数来衡量所述异质图自注意力网络模型的性能:loss=CrossEntropy(Y,P);
其中,P={p1,p2,p3…pn},表示模型的与测试;Y={y1,y2,y3…yn},表示模型的标签;
步骤S4、利用所述测试集对验证后的异质图自注意力网络模型进行测试,并不断优化所述异质图自注意力网络模型的超参数;
步骤S5、利用测试后的所述异质图自注意力网络模型进行交通异质图网络的节点分类,进而进行交通流预测。
所述步骤S1中,所述预设比例为2:1:7。
所述步骤S1中,所述节点特征矩阵为:
X∈RN×d;
所述邻接矩阵集合为不同类型边的邻接矩阵集合,公式为:
;
其中,X表示节点特征;R表示实数;N表示节点数量;d表示节点特征的输入维度;A表示邻接矩阵;K表示异质图的边的类型数;k表示邻接矩阵编号;
所述邻接矩阵集合可简写为张量A∈RN×N×K。
所述步骤S2中,所述全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示;所述全局自注意力模块主要由残差连接、前馈网络(FeedForward)和多头注意力(Multi-Head Attention)组成,其中,残差连接用于缓解模型的过拟合;多头注意力用于学习交通异质图网络节点全局的特征依赖关系;
所述全局自注意力模块的学习过程为:
S211、将所述节点特征矩阵X经过输入嵌入层(Input Embedding层),分别通过三个可学习的矩阵WQ、WK、WV投影为Q、K、V:
Q=XWQ,K=XWK,V=XWV;
其中,WQ∈Rd×dk;WK∈Rd×dk;WV∈Rd×dv;dk=dv=d;
S212、对所述Q、K、V应用归一化的点乘注意力机制计算自注意力矩阵SAttn:
;
其中,softmax()表示归一化指数函数;T表示矩阵转置操作;
S213、并行执行多次归一化的点乘注意力机制,把计算得到的各所述自注意力矩阵SAttn相加取均值,得到节点嵌入XMHead:
;
其中,XMHead∈RN×d,表示经过多头注意力机制学习得到的节点嵌入;Head表示多头注意力机制的头数,即需要执行多头注意力机制的次数;W0∈Rd×dv;
引入多头注意力机制为了稳定自注意力的计算结果;
S214、对所述节点嵌入XMHead与Q做残差连接后进行归一化,并在多头注意力机制中引入残差连接,得到节点嵌入XN1:
XN1=Norm(Q+XMHead(Q,K,V));
其中,XN1∈RN×d,表示经过第一次归一化后得到的节点嵌入;Norm()表示归一化函数;
S215、将所述节点嵌入XN1输送到由两层线性连接层组成的前馈网络,并在两个所述线性连接层之间使用激活函数Relu来增加全局自注意力模块的非线性,得到节点嵌入XFFN:
XFFN=Linear(Relu(Linear(XN1)));
其中,XFFN∈RN×d,表示经过前馈网络后得到的节点嵌入;Linear()表示线性连接层;
S216、对所述节点嵌入XFFN与XN1做残差连接后进行归一化,并引入残差连接,得到节点嵌入XN2:
XN2=Norm(XN1+XFFN);
S217、对所述节点特征矩阵X和节点嵌入XN2进行拼接,得到节点特征表示XG:
XG=X‖XN2;
其中,XG∈RN×2d;‖表示拼接操作。
此步骤为了缓解模型的过拟合,为了防止在经过全局自注意力模块学习过程中丢弃掉一些有用的、原始的节点特征信息,在全局自注意力模块的最外层加入一个做拼接操作的残差连接。
所述步骤S2中,所述图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示;所述图自注意力模块由图卷积层(Graph Convolution)和图转换注意力层(Graph Trans-Attention层)交替叠加四层组成;利用图卷积层学习节点及其周围一阶邻居的特征信息;图转换注意力层对多头注意力机制进行了改变,在其学习过程中加入异质图网络边的特征信息,并加入了一个门控单元Gate来防止模型的过平滑;在整个图自注意力模块中的每一个传播层之后,都引入激活函数ReLU来提高模型的非线性拟合能力;
所述图自注意力模块的学习过程为:
S221、把不同类型边所构成的邻接矩阵A聚合在一起,得到新的邻接矩阵AC:
AC=Conv(A;WC)=AWC;
其中,AC∈RN×N;Conv()表示卷积函数;WC∈RK×1×1,表示可学习的参数矩阵;
为了不丢失节点自身的特征,在每种类型关系的邻接矩阵上添加自连接的边,即在聚合前每一个不同类型关系的邻接矩阵Ak加上单位矩阵;
S222、在所述邻接矩阵AC、节点特征表示XG的基础上,利用图卷积层(GraphConvolution)学习交通异质图网络的节点以及其一阶邻居的特征信息,得到节点嵌入XC:
XC=Relu(GraphConv(XG;AC))=Relu(ACXGW);
其中,XC∈RN×dout,表示经过图卷积层学习得到的节点嵌入;dout表示输出的嵌入维度;GraphConv()表示图卷积操作;W∈R2d×dout,表示图卷积的权重矩阵;
S223、在经过图卷积层学习之后,为了能够学习到异质图网络节点特征的高阶语义信息,进一步使用多头注意力机制,特别是在考虑异质图网络结构信息的情况下,将多头注意力机制进行改变,加入异质图网络边的特征信息,设计一个Graph Trans-Attention层来学习交通异质图网络节点特征的高阶信息,即给定节点嵌入XC=[x1,x2…xN]T∈RN×dout,xN∈Rdout,表示节点N的特征表示;对于存在连接边的节点i和节点j,使用可学习参数Wq、Wk、bq、bk,将节点i的特征xi和节点j的特征xj分别转化为qi和kj:
qi=Wqxi+bq;
kj=Wkxj+bk;
其中,qi∈Rdout,kj∈Rdout,均为向量;
S224、将所述邻接矩阵AC通过可学习参数We、be转换为边缘特征eij,将所述边缘特征eij加入向量kj,得到向量kj’:
eij=WeAij+be;
kj’=kj+eij;
其中,Aij为邻接矩阵AC中的元素值,表示节点i和节点j之间存在相连的边;
S225、计算从节点j到节点i的每一条边的归一化点乘注意力αij:
;
;
其中,exp()表示以自然常数e为底的指数函数;N(i)表示节点i基于邻接矩阵AC的一阶邻居节点,包括其自身;
S226、通过可学习参数Wv、bv将节点j的特征xj转换为vj:
vj=Wvxj+bv;
其中,vj∈Rdout;
S227、基于所述vj、αij、eij计算多头注意力,得到节点嵌入zi:
;
即独立计算Head次注意力,取平均值作为节点i的节点嵌入zi;
S228、为了防止模型的过平滑,在Graph Trans-Attention层中,对所述节点嵌入引入门控单元Gate以及残差连接,得到节点嵌入:
ri=Wrxi+br;
di=zi‖ri‖(zi-ri);
;
;
其中,Wr、br、Wg均为可学习参数,且Wg∈R3dout;i表示节点编号;T表示转置操作;‖表示拼接操作;d表示拼接操作后得到的矢量;
S229、对所述节点嵌入进行归一化,得到节点嵌入Zi:
;
其中,Zi∈Rdout;引入激活函数来增加模型的非线性表示能力;
S230、重复两次S221-S229的学习过程,在经过所述图自注意力模块的学习后,获得所有节点最终的节点嵌入Z,Z∈RN×dout。
所述步骤S2中,所述输出模块用于预测节点类别;
所述输出模块的计算过程为:
将所述节点嵌入Z输入两个全连接层(MLP)和softmax函数得到预测的节点类别P:
P=softmax(Linear(Linear(Z)));
其中,P∈R1×n,n表示节点类别数。
所述步骤S4中,所述超参数至少包括随机失活率(dropout)、权值衰减率(weight-decay)以及学习率。
具体实施时,训练总迭代次数为50次,优化器使用Adam;全局自注意力模块的学习率设置为0.0004、权值衰减率设置为0.001;图自注意力模块的学习率设置为0.005、权值衰减率设置为0.001;输出模块的学习率设置为0.001、权值衰减率设置为0.001;根据不同的数据集对随机失活率进行调整,ACM的随机失活率为0.3,DBLP的随机失活率为0.0,IMDB的随机失活率为0.5。
所述异质图自注意力网络模型在ACM、DBLP和IMDB三个异质图网络公共数据集中,使用图节点分类任务指标F1-macro和F1-micro对模型的特征学习能力进行评测,结果表明能够对异质图网络的节点特征进行有效的学习,并且实验的结果超越了传统方法。
综上所述,本发明的优点在于:
通过获取大量的交通异质图网络的数据集并划分为训练集、验证集和测试集,从训练集、验证集和测试集中分别提取异质图网络的节点特征矩阵和邻接矩阵集合;基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型,利用训练集对异质图自注意力网络模型进行训练,利用验证集对训练后的异质图自注意力网络模型进行验证,利用测试集对验证后的异质图自注意力网络模型进行测试,并不断优化异质图自注意力网络模型的超参数,最后利用测试后的异质图自注意力网络模型进行交通流预测;由于全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示,图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示,在整个学习过程中不需要使用元路径,并能够更好学习交通异质图网络丰富的特征信息和高阶语义信息,具有更强大的异质图网络的节点特征学习能力,进而极大的提升了交通流预测的准确率。
虽然以上描述了本发明的具体实施方式,但是熟悉本技术领域的技术人员应当理解,我们所描述的具体的实施例只是说明性的,而不是用于对本发明的范围的限定,熟悉本领域的技术人员在依照本发明的精神所作的等效的修饰以及变化,都应当涵盖在本发明的权利要求所保护的范围内。
Claims (7)
1.一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:包括如下步骤:
步骤S1、获取大量的交通异质图网络的数据集,按预设比例将所述数据集划分为训练集、验证集和测试集,从所述训练集、验证集和测试集中分别提取交通异质图网络的节点特征矩阵和邻接矩阵集合;
步骤S2、基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型;
步骤S3、利用所述训练集对异质图自注意力网络模型进行训练,利用所述验证集对训练后的异质图自注意力网络模型进行验证;
步骤S4、利用所述测试集对验证后的异质图自注意力网络模型进行测试,并不断优化所述异质图自注意力网络模型的超参数;
步骤S5、利用测试后的所述异质图自注意力网络模型进行交通异质图网络的节点分类,进而进行交通流预测。
2.如权利要求1所述的一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:所述步骤S1中,所述预设比例为2:1:7。
3.如权利要求1所述的一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:所述步骤S1中,所述节点特征矩阵为:
X∈RN×d;
所述邻接矩阵集合为不同类型边的邻接矩阵集合,公式为:其中,X表示节点特征;R表示实数;N表示节点数量;d表示节点特征的输入维度;A表示邻接矩阵;K表示异质图的边的类型数;k表示邻接矩阵编号。
4.如权利要求3所述的一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:所述步骤S2中,所述全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示;
所述全局自注意力模块的学习过程为:
S211、将所述节点特征矩阵X分别通过三个可学习的矩阵WQ、WK、WV投影为Q、K、V:
Q=XWQ,K=XWK,V=XWV;
其中,WQ∈Rd×dk;WK∈Rd×dk;WV∈Rd×dv;dk=dv=d;
S212、对所述Q、K、V应用归一化的点乘注意力机制计算自注意力矩阵SAttn:
;
其中,softmax()表示归一化指数函数;T表示矩阵转置操作;
S213、并行执行多次归一化的点乘注意力机制,把计算得到的各所述自注意力矩阵SAttn相加取均值,得到节点嵌入XMHead:
;
其中,XMHead∈RN×d,表示经过多头注意力机制学习得到的节点嵌入;Head表示多头注意力机制的头数;W0∈Rd×dv;
S214、对所述节点嵌入XMHead与Q做残差连接后进行归一化,得到节点嵌入XN1:
XN1=Norm(Q+XMHead(Q,K,V));
其中,XN1∈RN×d,表示经过第一次归一化后得到的节点嵌入;Norm()表示归一化函数;
S215、将所述节点嵌入XN1输送到由两层线性连接层组成的前馈网络,并在两个所述线性连接层之间使用激活函数Relu来增加全局自注意力模块的非线性,得到节点嵌入XFFN:
XFFN=Linear(Relu(Linear(XN1)));
其中,XFFN∈RN×d,表示经过前馈网络后得到的节点嵌入;Linear()表示线性连接层;
S216、对所述节点嵌入XFFN与XN1做残差连接后进行归一化,得到节点嵌入XN2:
XN2=Norm(XN1+XFFN);
S217、对所述节点特征矩阵X和节点嵌入XN2进行拼接,得到节点特征表示XG:
XG=X‖XN2;
其中,XG∈RN×2d;‖表示拼接操作。
5.如权利要求4所述的一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:所述步骤S2中,所述图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示;
所述图自注意力模块的学习过程为:
S221、把不同类型边所构成的邻接矩阵A聚合在一起,得到新的邻接矩阵AC:
AC=Conv(A;WC)=AWC;
其中,AC∈RN×N;Conv()表示卷积函数;WC∈RK×1×1,表示可学习的参数矩阵;
S222、在所述邻接矩阵AC、节点特征表示XG的基础上,利用图卷积层学习交通异质图网络的节点以及其一阶邻居的特征信息,得到节点嵌入XC:
XC=Relu(GraphConv(XG;AC))=Relu(ACXGW);
其中,XC∈RN×dout,表示经过图卷积层学习得到的节点嵌入;dout表示输出的嵌入维度;GraphConv()表示图卷积操作;W∈R2d×dout,表示图卷积的权重矩阵;
S223、给定节点嵌入XC=[x1,x2…xN]T∈RN×dout,xi∈Rdout,表示节点N的特征表示;对于存在连接边的节点i和节点j,使用可学习参数Wq、Wk、bq、bk,将节点i的特征xi和节点j的特征xj分别转化为qi和kj:
qi=Wqxi+bq;
kj=Wkxj+bk;
其中,qi∈Rdout,kj∈Rdout,均为向量;
S224、将所述邻接矩阵AC通过可学习参数We、be转换为边缘特征eij,将所述边缘特征eij加入向量kj,得到向量kj’:
eij=WeAij+be;
kj’=kj+eij;
其中,Aij为邻接矩阵AC中的元素值,表示节点i和节点j之间存在相连的边;
S225、计算从节点j到节点i的每一条边的归一化点乘注意力αij:
;
;
其中,exp()表示以自然常数e为底的指数函数;N(i)表示节点i基于邻接矩阵AC的一阶邻居节点;
S226、通过可学习参数Wv、bv将节点j的特征xj转换为vj:
vj=Wvxj+bv;
其中,vj∈Rdout;
S227、基于所述vj、αij、eij计算多头注意力,得到节点嵌入zi:
;
S228、对所述节点嵌入zi引入门控单元Gate以及残差连接,得到节点嵌入:
ri=Wrxi+br;
di=zi‖ri‖(zi-ri);
;
;
其中,Wr、br、Wg均为可学习参数,且Wg∈R3dout;i表示节点编号;T表示转置操作;‖表示拼接操作;d表示拼接操作后得到的矢量;
S229、对所述节点嵌入进行归一化,得到节点嵌入Zi:
;
其中,Zi∈Rdout;S230、重复两次S221-S229的学习过程,在经过所述图自注意力模块的学习后,获得所有节点最终的节点嵌入Z,Z∈RN×dout。
6.如权利要求5所述的一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:所述步骤S2中,所述输出模块用于预测节点类别;
所述输出模块的计算过程为:
将所述节点嵌入Z输入两个全连接层和softmax函数得到预测的节点类别P:
P=softmax(Linear(Linear(Z)));
其中,P∈R1×n,n表示节点类别数。
7.如权利要求1所述的一种基于图自注意力机制的异质图网络节点分类方法,其特征在于:所述步骤S4中,所述超参数至少包括随机失活率、权值衰减率以及学习率。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311099604.1A CN116821776B (zh) | 2023-08-30 | 2023-08-30 | 一种基于图自注意力机制的异质图网络节点分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311099604.1A CN116821776B (zh) | 2023-08-30 | 2023-08-30 | 一种基于图自注意力机制的异质图网络节点分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116821776A true CN116821776A (zh) | 2023-09-29 |
CN116821776B CN116821776B (zh) | 2023-11-28 |
Family
ID=88114842
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311099604.1A Active CN116821776B (zh) | 2023-08-30 | 2023-08-30 | 一种基于图自注意力机制的异质图网络节点分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116821776B (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117131938A (zh) * | 2023-10-26 | 2023-11-28 | 合肥工业大学 | 基于图深度学习的动态隐性关系挖掘方法和系统 |
CN117218868A (zh) * | 2023-11-07 | 2023-12-12 | 福建理工大学 | 一种基于几何散射图网络的交通流预测方法 |
CN117435995A (zh) * | 2023-12-20 | 2024-01-23 | 福建理工大学 | 一种基于残差图网络的生物医药分类方法 |
CN118035323A (zh) * | 2024-04-12 | 2024-05-14 | 四川航天职业技术学院(四川航天高级技工学校) | 应用于数字化校园软件服务的数据挖掘方法及系统 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112288011A (zh) * | 2020-10-30 | 2021-01-29 | 闽江学院 | 一种基于自注意力深度神经网络的图像匹配方法 |
CN114565053A (zh) * | 2022-03-10 | 2022-05-31 | 天津大学 | 基于特征融合的深层异质图嵌入模型 |
WO2023087558A1 (zh) * | 2021-11-22 | 2023-05-25 | 重庆邮电大学 | 基于嵌入平滑图神经网络的小样本遥感图像场景分类方法 |
CN116597824A (zh) * | 2023-05-19 | 2023-08-15 | 杭州电子科技大学 | 一种基于注意力引导张量网络的想象语音分类方法及系统 |
CN116628597A (zh) * | 2023-07-21 | 2023-08-22 | 福建理工大学 | 一种基于关系路径注意力的异质图节点分类方法 |
-
2023
- 2023-08-30 CN CN202311099604.1A patent/CN116821776B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112288011A (zh) * | 2020-10-30 | 2021-01-29 | 闽江学院 | 一种基于自注意力深度神经网络的图像匹配方法 |
WO2023087558A1 (zh) * | 2021-11-22 | 2023-05-25 | 重庆邮电大学 | 基于嵌入平滑图神经网络的小样本遥感图像场景分类方法 |
CN114565053A (zh) * | 2022-03-10 | 2022-05-31 | 天津大学 | 基于特征融合的深层异质图嵌入模型 |
CN116597824A (zh) * | 2023-05-19 | 2023-08-15 | 杭州电子科技大学 | 一种基于注意力引导张量网络的想象语音分类方法及系统 |
CN116628597A (zh) * | 2023-07-21 | 2023-08-22 | 福建理工大学 | 一种基于关系路径注意力的异质图节点分类方法 |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117131938A (zh) * | 2023-10-26 | 2023-11-28 | 合肥工业大学 | 基于图深度学习的动态隐性关系挖掘方法和系统 |
CN117131938B (zh) * | 2023-10-26 | 2024-01-19 | 合肥工业大学 | 基于图深度学习的动态隐性关系挖掘方法和系统 |
CN117218868A (zh) * | 2023-11-07 | 2023-12-12 | 福建理工大学 | 一种基于几何散射图网络的交通流预测方法 |
CN117218868B (zh) * | 2023-11-07 | 2024-03-22 | 福建理工大学 | 一种基于几何散射图网络的交通流预测方法 |
CN117435995A (zh) * | 2023-12-20 | 2024-01-23 | 福建理工大学 | 一种基于残差图网络的生物医药分类方法 |
CN117435995B (zh) * | 2023-12-20 | 2024-03-19 | 福建理工大学 | 一种基于残差图网络的生物医药分类方法 |
CN118035323A (zh) * | 2024-04-12 | 2024-05-14 | 四川航天职业技术学院(四川航天高级技工学校) | 应用于数字化校园软件服务的数据挖掘方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN116821776B (zh) | 2023-11-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116821776B (zh) | 一种基于图自注意力机制的异质图网络节点分类方法 | |
Natesan Ramamurthy et al. | Model agnostic multilevel explanations | |
WO2018212710A1 (en) | Predictive analysis methods and systems | |
CN112861936B (zh) | 一种基于图神经网络知识蒸馏的图节点分类方法及装置 | |
CN112529168A (zh) | 一种基于gcn的属性多层网络表示学习方法 | |
CN113961759B (zh) | 基于属性图表示学习的异常检测方法 | |
US20230195809A1 (en) | Joint personalized search and recommendation with hypergraph convolutional networks | |
CN112257841A (zh) | 图神经网络中的数据处理方法、装置、设备及存储介质 | |
CN116628597B (zh) | 一种基于关系路径注意力的异质图节点分类方法 | |
CN112667824A (zh) | 基于多语义学习的知识图谱补全方法 | |
CN112529071A (zh) | 一种文本分类方法、系统、计算机设备和存储介质 | |
CN113449853A (zh) | 一种图卷积神经网络模型及其训练方法 | |
CN115310837A (zh) | 基于因果图注意力神经网络的复杂机电系统故障检测方法 | |
CN116976505A (zh) | 基于信息共享的解耦注意网络的点击率预测方法 | |
Zhang et al. | An intrusion detection method based on stacked sparse autoencoder and improved gaussian mixture model | |
CN111428181A (zh) | 一种基于广义加性模型结合矩阵分解的银行理财产品推荐方法 | |
Richard et al. | Link discovery using graph feature tracking | |
Raghavendra et al. | Evaluation of feature selection methods for predictive modeling using neural networks in credits scoring | |
CN113159976B (zh) | 一种微博网络重要用户的识别方法 | |
CN112581177B (zh) | 结合自动特征工程及残差神经网络的营销预测方法 | |
CN113626685A (zh) | 一种面向传播不确定性的谣言检测方法及装置 | |
CN112836763A (zh) | 一种图结构数据分类方法及装置 | |
PCD et al. | Advanced lightweight feature interaction in deep neural networks for improving the prediction in click through rate | |
Zhang et al. | An interpretable neural model with interactive stepwise influence | |
US20240355091A1 (en) | Techniques to perform global attribution mappings to provide insights in neural networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |