CN115661550B - 基于生成对抗网络的图数据类别不平衡分类方法及装置 - Google Patents

基于生成对抗网络的图数据类别不平衡分类方法及装置 Download PDF

Info

Publication number
CN115661550B
CN115661550B CN202211461517.1A CN202211461517A CN115661550B CN 115661550 B CN115661550 B CN 115661550B CN 202211461517 A CN202211461517 A CN 202211461517A CN 115661550 B CN115661550 B CN 115661550B
Authority
CN
China
Prior art keywords
node
nodes
class
representing
graph
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211461517.1A
Other languages
English (en)
Other versions
CN115661550A (zh
Inventor
张阳
还章军
余婷
张吉
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202211461517.1A priority Critical patent/CN115661550B/zh
Publication of CN115661550A publication Critical patent/CN115661550A/zh
Application granted granted Critical
Publication of CN115661550B publication Critical patent/CN115661550B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了基于生成对抗网络的图数据类别不平衡分类方法及装置,通过构建生成器,将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量;再对少样本图数据进行过采样,根据合成少数类过采样技术SMOTE原则,对学习到的每个少数类表示向量进行近邻计算,选择其最近邻节点进行插值计算,生成新的节点;然后重建平衡图数据,通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测;最后将重建的平衡图数据作为判别器的输入,进行节点分类。有效解决了数据不平衡的假阳性问题,提高了图的节点分类准确率。

Description

基于生成对抗网络的图数据类别不平衡分类方法及装置
技术领域
本发明涉及图数据挖掘、数据不平衡技术领域,尤其是涉及基于生成对抗网络的图数据类别不平衡分类方法及装置。
背景技术
数据不平衡研究是一个经典的机器学习问题,广泛存在于工业生产、计算机视觉、信息安全等诸多领域,是近年来持续研究热点之一。数据不平衡指的是不同类别的数据样本之间的不平衡,目前研究通常针对文本、图片等数值型数据,主要分为数据层面的样本过采样和欠采样,算法层面的代价敏感函数设计以及集成学习三种方法。在过采样技术中,SMOTE(Synthetic Minority Oversampling Technique,合成少数类过采样技术)算法是经典算法之一,常用来解决不平衡问题。然而,直接使用SMOTE方法对每个原少数类样本进行人工样本会带来过拟合问题,而没有考虑近邻样本,从而增加了不同类别样本的重叠。有鉴于此,许多变形算法被提出来克服这种缺陷。一些代表性的工作包括:Boderline-SMOTE,Adaptive Synthetic Sampling,Safe-Level-SMOTE以及SPIDER2算法。
但是上述方法都是应用于数值型数据中。近年来,随着计算机的软硬件发展,图神经网络已经应用到图数据的各个领域并取得不错的成果。图数据的不平衡问题也逐渐被研究学者发现。不少学者发现数据中经常存在倾斜类分布的帕累托失衡现象,这导致模型对多数类的有偏学习,使得模型精准识别少数类数据非常困难。如未考虑数据本身的倾斜分布,模型对数据的学习会产生一个具有欺骗性的高精度度量,在假阴性比假阳性代价更高的情况下,模型倾向于多数类的预测偏差可能会产生不良后果,特别是在异常检测等领域。
目前,在图数据领域的不平衡方法还在初步研究中,2020年提出的DR-GCN方法,在图卷积网络的基础上加入了条件对抗正则化层和潜在分布对齐正则化层用于解决不均衡数据多分类问题。2021年提出的GraphSmote方法,它将SMOTE应用到图数据中,通过过采样解决图的不平衡问题。但是简单的卷积只学习到了节点的局部特征,过采样仍然会导致过拟合问题,无法解决类重叠问题。
发明内容
为解决现有技术的不足,实现图节点分类的准确度的目的,本发明采用如下的技术方案:
一种基于生成对抗网络的图数据类别不平衡分类方法,包括如下步骤:
步骤S1:构建生成器;将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量;
步骤S2:对少样本图数据进行过采样;根据合成少数类过采样技术SMOTE原则,对学习到的每个少数类表示向量进行近邻计算,选择其最近邻节点进行插值计算,生成新的节点,其中K的取值取决于该类与多数类的不平衡比;
步骤S3:重建平衡图数据;通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测;
步骤S4:将重建的平衡图数据作为判别器的输入,进行节点分类。
进一步地,所述步骤1包括如下步骤:
步骤S1.1:提取图的空间结构;
步骤S1.2:提取图的低阶信息;
步骤S1.3:提取图的高阶信息;
步骤S1.4:将学习到的高阶信息拼接到低阶信息中,得到最后的表示向量;
步骤S1.5:步骤S1.1至步骤S1.4构成的生成器,包括对生成的少数类数据的混淆鉴别器损失和对生成数据的条件约束;
步骤S1.6:通过生成器和判别器动态更新参数,优化学习到的嵌入表示向量zi和生成节点ng
进一步地,所述步骤S1.2中,利用归纳式神经网络GraphSage学习节点的表示,具体如下:首先初始化节点的表示
Figure GDA0004128593710000021
V表示节点集,然后聚合T跳邻居节点的表示
Figure GDA0004128593710000022
t=1,2,3,…,T表示与邻居节点的相邻层数,mean表示求取{·}数组平均值的函数,最后将v节点t-1层邻居节点u的信息拼接到t层节点v的向量表示上,经过全连接层转换,得到节点v第t层的向量表示。
进一步地,所述步骤S1.1中,根据节点带有的属性信息,构造节点属性矩阵;所述步骤S1.3包括如下步骤:
步骤S1.3.1:首先使用图卷积网络GCN学习节点属性和拓扑信息,然后利用K最邻近算法构建原始超边,对每个节点进行近邻计算,并将其与近邻构成基础超边的集合eb
例如节点v经过计算,其超边为ev=knn(xu,xv,K),其中,xv表示节点v经过GCN学习到的嵌入表示,xu表示节点v的邻居节点经过GCN学习到的嵌入表示。
步骤S1.3.2:通过K均值聚类算法K-means对X={x1,x2,…,xN}的节点嵌入表示进行聚类,学习到S个聚类中心,计算每个节点到聚类中心的距离,然后将
Figure GDA0004128593710000031
的聚类中心加入到基础超边的集合eb中,由于超边是由多个节点组成的边,加入的聚类中心增加了节点数量,从而扩大了超边;
步骤S1.3.3:采用
Figure GDA0004128593710000032
表示超边e包含的顶点集合,ke表示超边e包含的顶点个数,/>
Figure GDA0004128593710000033
表示节点v包含的所有超边集合,kv表示包含节点v超边个数;对基础超图进行超图卷积,不断更新节点的表示,超图卷积如公式一所示,先通过多层感知机MLP学习到节点的转移状态矩阵,然后利用一维的双曲图卷积神经网络HGCN学习超边的向量表示,最后归一化超边的信息,将其聚合到节点v上,获得节点最终的高阶表示hv
进一步地,所述步骤S1.3.3中,超图卷积公式如下:
T=MLP(xu)
h′e=HGCN(T·MLP(xu))
W=softmax(h′eW+b)
Figure GDA0004128593710000034
其中,xu表示节点v的邻居节点经过图卷积网络GCN学习到的嵌入表示,T表示通过多层感知机MLP学习到节点的转移状态矩阵,h′e表示通过双曲图卷积神经网络HGCN学习超边的向量表示,W和b分别表示softmax激活函数的权重和偏置,w表示归一化后的超边信息,|Adj(v)|表示节点v包含的所有超边的数量,hv表示节点最终的高阶表示。
进一步地,所述步骤S1.1中,根据图的节点和边信息,构造图的邻接矩阵A;所述步骤S1.2中,使用解码器重构图数据,形成仅包含原始节点的重构邻接矩阵AD;所述步骤S1.4中,使用注意力机制提取两种表示中重要信息进行下一层传播,即最后的表示向量zv=cat[hv,xv],cat表示拼接操作,hv表示高阶信息,xv表示经步骤S1.2得到的节点v经过图卷积网络GCN学习到的嵌入表示,将其输入步骤S2中,生成新节点;
所述步骤S1.5中,对生成的少数类数据的混淆鉴别器损失,包括Lrf用于判断节点是生成节点还是真实节点,通过损失训练使判别器将生成节点识别为真实节点,Lmaj用于控制生成的少数类节点尽可能远离多数类;对生成数据的条件约束,包括Ldis用于将生成的少数类节点靠近真实的少数类节点,Lrec用于控制编码器学习图的真实信息。
Figure GDA0004128593710000041
其中,
Figure GDA0004128593710000042
表示第i个节点在学习到表示向量zi时属于真实类的概率
Figure GDA0004128593710000043
zi表示第i个节点的最终向量表示,real表示节点属于真实类,majority表示节点属于多数类,/>
Figure GDA0004128593710000044
表示第i个节点的预测标签,qi表示第i个节点的真实标签,qj表示第j个节点的真实标签,p(zi)表示第i个节点属于少数类的概率,ng表示生成的节点集,nmin表示少数类节点集,/>
Figure GDA0004128593710000045
表示正则化。
进一步地,所述步骤S2中,计算少数类l中节点v的最近邻节点,
Figure GDA0004128593710000046
s.t.lu=lv,其中/>
Figure GDA0004128593710000047
表示节点v属于类别l的表示向量,/>
Figure GDA0004128593710000048
表示节点u属于类别l的表示向量,nn(v)表示同一类别中距离节点v最近的邻节点,argmin||·||表示取距离最近操作;然后生成新的合成节点/>
Figure GDA0004128593710000049
Figure GDA00041285937100000410
δ表示平衡系数。
进一步地,所述步骤S3中,使用点积操作进行边预测,节点u和v的边概率是
Figure GDA00041285937100000411
边生成器的损失函数loss为/>
Figure GDA00041285937100000412
其中W是线性函数softmax的权重矩阵,E表示图的边集,A表示根据图的节点和边信息构造的图的邻接矩阵,当预测概率大于阈值时,认为节点u和v存在边,通过不断地优化学习,最终得到重构图的边信息。
进一步地,所述步骤S4中,利用谱图神经网络GCN学习节点的嵌入表示,并结合softmax函数进行多类别分类;判别器的损失loss函数如公式三所示,分别是真实节点和生成节点是否是多数类,以及多数类和少数类的交叉熵函数,具体如下:
Figure GDA0004128593710000051
其中,Lfa是用以区分节点是真实节点还是生成器生成的节点的交叉熵损失;Lcl是用以区分节点是少数类还是多数类的交叉熵损失,将节点数最多的一组类别作为多数类,其他类别为少数类,且保留原始的类别信息,减号表示希望数据尽可能远离多数类;Ldis是用以扩大不同的类节点之间的嵌入距离的损失函数;
Figure GDA0004128593710000052
表示第i个节点在学习到表示向量zi时属于伪类fake的概率/>
Figure GDA0004128593710000053
zi表示第i个节点的最终向量表示,minority表示节点属于少数类,/>
Figure GDA0004128593710000054
表示第i个节点的预测标签,qi表示第i个节点的真实标签,qj表示第j个节点的真实标签,p(zi)表示第i个节点属于少数类的概率,ng表示生成的节点集,nmin表示少数类节点集,nmaj表示多数类节点集。
基于生成对抗网络的图数据类别不平衡分类装置,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现所述的基于生成对抗网络的图数据类别不平衡分类方法。
本发明的优势和有益效果在于:
本发明的基于生成对抗网络的图数据类别不平衡分类方法及装置,通过高低阶构图,学习节点的局部和全局信息,并结合生成对抗思路,动态更新生成节点,有效解决了图数据不平衡问题,同时实验表明,本发明优于现有的SOTA方法。
附图说明
图1是本发明实施例中方法的流程图。
图2是本发明实施例中不平衡动态卷积生成对抗网络的原理图。
图3是本发明实施例中cora数据集的实验结果图。
图4a是本发明实施例中不考虑数据本身的不平衡问题时节点分类准确度示意图。
图4b是本发明实施例中考虑数据本身的不平衡问题时节点分类准确度示意图。
图5是本发明实施例中装置的结构示意图。
具体实施方式
以下结合附图对本发明的具体实施方式进行详细说明。应当理解的是,此处所描述的具体实施方式仅用于说明和解释本发明,并不用于限制本发明。
如图1、图2所示,基于生成对抗网络的图数据类别不平衡分类方法,包含如下步骤:
步骤S1:构建生成器。将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量Z,包括如下步骤:步骤S1.1:提取图的空间结构:原始图G=(V,E),V表示节点集,E表示边集。根据图的节点和边信息,构造图的邻接矩阵A∈RN*N,A用于表示图的拓扑结构特征;根据节点带有的属性信息,构造节点属性矩阵F∈RN*M,其中,N表示节点的总数,M表示节点属性空间的总维度。
步骤S1.2:提取图的低阶信息:利用归纳式神经网络GraphSage学习节点的表示,具体如下:首先初始化节点的表示
Figure GDA0004128593710000061
然后聚合T跳邻居节点的表示
Figure GDA0004128593710000062
t=1,2,3,…,T表示与邻居节点的相邻层数,mean表示求取{·}数组平均值的函数,最后将v节点t-1层邻居节点u的信息拼接到t层节点v的向量表示上,经过全连接层转换,得到节点v第t层的向量表示,使用解码器重构图数据,形成仅包含原始节点的重构邻接矩阵AD,具体地,图2中重构邻接矩阵AD,通过对节点的表示h及其转置后的hT进行点积,再通过sigmoid激活函数得到。
步骤S1.3:提取图的高阶信息:该方法提出一种动态超图构建方法,通过不断迭代优化学习节点的高阶信息,并将其与低阶信息融合,以便于步骤S2生成优质的节点。具体如下:
步骤S1.3.1:首先使用GCN(Graph Convolutional Network,图卷积网络)学习节点属性和拓扑信息,然后利用KNN(K-NearestNeighbor,K最邻近)算法构建原始超边,对每个节点进行近邻计算,并将其与近邻构成基础超边的集合eb。例如节点v经过计算,其超边为ev=knn(xu,xv,K),其中,xv表示节点v经过GCN学习到的嵌入表示,xu表示节点v的邻居节点经过GCN学习到的嵌入表示。
步骤S1.3.2:通过K-means对X={x1,x2,…,xN}的节点嵌入表示进行聚类,学习到S个聚类中心,计算每个节点到聚类中心的距离,然后将
Figure GDA0004128593710000071
的聚类中心加入到eb中,由于超边是由多个节点组成的边,加入的聚类中心增加了节点数量,从而扩大了超边。
步骤S1.3.3:采用
Figure GDA0004128593710000072
表示超边e包含的顶点集合,ke表示超边e包含的顶点个数,/>
Figure GDA0004128593710000073
表示节点v包含的所有超边集合,kv表示包含节点v超边个数;对基础超图进行超图卷积,不断更新节点的表示,超图卷积如公式一所示,先通过MLP(Multilayer Perceptron,多层感知机)学习到节点的转移状态矩阵,然后利用一维的HGCN学习超边的向量表示,最后归一化超边的信息,将其聚合到节点v上,获得节点最终的高阶表示hv
T=MLP(xu)
Figure GDA0004128593710000074
其中,T表示通过MLP学习到节点的转移状态矩阵,h′e表示通过HGCN(HyperbolicGraph Convolutional Neural Network,双曲图卷积神经网络)学习超边的向量表示,W和b分别表示softmax激活函数的权重和偏置,w表示归一化后的超边信息,|Adj(v)|表示节点v包含的所有超边的数量。
步骤S1.4:将学习到的高阶信息拼接到低阶信息中,并使用注意力机制提取两种表示中重要信息进行下一层传播,即最后的表示向量zv=cat[hv,xv],将其输入步骤S2中,生成新节点。
步骤S1.5:步骤S1.1至步骤S1.4构成的生成器,生成器的loss函数如公式二所示,同样由四部分组成,前两项是对生成的少数类数据的混淆鉴别器损失,Lrf用于判断节点是生成节点还是真实节点,通过损失训练使判别器将生成节点识别为真实节点,Lmaj用于控制生成的少数类节点尽可能远离多数类。后两项是对生成数据的条件约束,Ldis的目的是将生成的少数类节点靠近真实的少数类节点,Lrec的目的是控制编码器学习图的真实信息;
Figure GDA0004128593710000081
其中,
Figure GDA0004128593710000082
表示第i个节点在学习到表示向量zi时属于真实类的概率
Figure GDA0004128593710000083
zi表示第i个节点的最终向量表示,real表示节点属于真实类,majority表示节点属于多数类,/>
Figure GDA0004128593710000084
表示第i个节点的预测标签,qi表示第i个节点的真实标签,qj表示第j个节点的真实标签,p(zi)表示第i个节点属于少数类的概率,ng表示生成的节点集,nmin表示少数类节点集,/>
Figure GDA0004128593710000085
表示正则化,具体地,图2中Lrec即为/>
Figure GDA0004128593710000086
步骤S1.6:通过生成器和判别器动态更新模型参数,优化学习到的嵌入向量zi和生成节点ng
步骤S2:对少样本图数据进行过采样。根据SMOTE原则,对学习到的每个少数类表示向量Zl进行K近邻计算,选择其最近邻节点进行插值计算,生成新的节点Ng,其中K的取值取决于该类与多数类的不平衡比。
例如,计算少数类l中节点v的最近邻节点,
Figure GDA0004128593710000087
s.t.lu=lv,其中/>
Figure GDA0004128593710000088
表示节点v属于类别l的表示向量,/>
Figure GDA0004128593710000089
表示节点u属于类别l的表示向量,nn(v)表示同一类别中距离节点v最近的邻节点,argmin||·||表示取距离最近操作;然后生成新的合成节点/>
Figure GDA00041285937100000810
δ表示平衡系数。
步骤S3:重建平衡图数据G’。通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测。
具体地,该方法使用点积操作进行边预测。节点u和v的边概率是
Figure GDA00041285937100000811
边生成器的损失函数loss为/>
Figure GDA00041285937100000812
其中W是线性函数的权重矩阵。当预测概率大于阈值0.5时,我们则认为节点u和v存在边。通过不断地优化学习,最终得到重构图的边信息。
步骤S4:将重建的平衡图数据G’作为判别器的输入,进行节点分类。在这里我们利用谱图神经网络GCN学习节点的嵌入表示,并结合softmax函数进行多类别分类。判别器的损失loss函数如公式三所示,分别是真实节点和生成节点是否是多数类,以及多数类和少数类的交叉熵函数。其中,第二项Lcl是因为该方法将节点数top1的类别当做多数类,其他均为少数类,且保留原始的类别信息。这里使用减号,是希望数据尽可能远离多数类。
Figure GDA0004128593710000091
其中,Lfa是用以区分节点是真实节点还是生成器生成的节点的交叉熵损失;Lcl是用以区分节点是少数类还是多数类的交叉熵损失,将节点数最多的一组类别作为多数类,其他类别为少数类,且保留原始的类别信息,减号表示希望数据尽可能远离多数类;Ldis是用以扩大不同的类节点之间的嵌入距离的损失函数;
Figure GDA0004128593710000092
表示第i个节点在学习到表示向量zi时属于伪类fake的概率/>
Figure GDA0004128593710000093
zi表示第i个节点的最终向量表示,minority表示节点属于少数类,/>
Figure GDA0004128593710000094
表示第i个节点的预测标签,qi表示第i个节点的真实标签,qj表示第j个节点的真实标签,p(zi)表示第i个节点属于少数类的概率,ng表示生成的节点集,nmin表示少数类节点集,nmaj表示多数类节点集。
本发明实施例中,以cora数据集作为输入的图数据,来举例说明:
步骤S101,输入图数据。该数据集包含节点总数2708,节点特征总维度1433。节点的邻接矩阵是一个2708*2708维的矩阵,存储了每个节点的邻节点信息。节点属性矩阵F是一个2708*1433维矩阵,存储了每个节点的属性信息。根据节点对应的论文内容可将其分为7类,其中神经网络占比30.21%,遗传算法占比15.44%,概率学方法15.73%,理论12.96%,基于案例11%,增强学习8.01%,规则学习占比6.65%。
步骤S102,使用图卷积神经网络GraphSage学习节点和属性的融合信息,得到嵌入向量h。
步骤S103,使用KNN和K-means初始化超图,然后使用超图卷积学习节点的高阶信息,并不断迭代更新超图,获得嵌入向量x。
步骤S104,融合低阶表示h和高阶的嵌入表示x,得到节点的整体表示Z,根据SMOTE过采样规则,对少数类节点进行过采样,这里除了神经网络类别是多数类之外,其他6类均为少数类需要进行节点生成。例如初始类别训练集为[237,164,288,561,291,228,126],则我们需对其他节点进行补齐,得到最终的训练数据集为[561,561,561,561,561,561,561]。同时我们通过判别器反馈以及生成器损失loss进行更新。
步骤S105,将补齐之后的数据输入到判别器中,对模型进行训练,最终的分类结果如图3所示,与其他方法对比结果如表1所示。
表1对比实验结果表
recall f1 auc acc pre
GCN 0.6442 0.6245 0.8435 0.6654 0.6892
Smote 0.6883 0.6897 0.9038 0.6883 0.7033
Graph-smote 0.726 0.7153 0.9275 0.726 0.7423
GraphENS 0.6848 0.6915 0.9204 0.736 0.7509
imGANSmote 0.857 0.8452 0.9646 0.8586 0.8365
表中,基于Recall(召回率)评价指标、F1(F1值)评价指标、Auc(Area UnderCurve,曲线下面积)评价指标、Acc(Accuracy,准确率)评价指标、Pre(精确率)评价指标,分别对GCN(图卷积神经网络)方法、SMOTE(合成少数类过采样技术)方法、GraphSMOTE(基于GNN的合成少数类过采样技术)方法和本发明的imGANSmote(基于生成对抗网络的图数据类别不平衡分类方法)方法进行评价的结果,根据实验结果可以看出,本发明的方法在图像分类精度上由于其他方法。
步骤S106,为了进一步确定该方法的有效性,我们进行了数据不平衡的消融实验。如图4a所示,imGANSmote当不考虑数据本身的不平衡问题时,节点分类准确度达到0.73(对角的准确度均值),但仔细分析结果会发现类别4、5、6的预测结果并不理想。如图4b所示,当考虑数据不平衡问题时,节点分类的正确性达到0.805,且每一类准确度都达到了0.69以上。
与前述基于生成对抗网络的图数据类别不平衡分类方法的实施例相对应,本发明还提供了基于生成对抗网络的图数据类别不平衡分类装置的实施例。
参见图5,本发明实施例提供的基于生成对抗网络的图数据类别不平衡分类装置,包括存储器和一个或多个处理器,存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现上述实施例中的基于生成对抗网络的图数据类别不平衡分类方法。
本发明基于生成对抗网络的图数据类别不平衡分类装置的实施例可以应用在任意具备数据处理能力的设备上,该任意具备数据处理能力的设备可以为诸如计算机等设备或装置。装置实施例可以通过软件实现,也可以通过硬件或者软硬件结合的方式实现。以软件实现为例,作为一个逻辑意义上的装置,是通过其所在任意具备数据处理能力的设备的处理器将非易失性存储器中对应的计算机程序指令读取到内存中运行形成的。从硬件层面而言,如图5所示,为本发明基于生成对抗网络的图数据类别不平衡分类装置所在任意具备数据处理能力的设备的一种硬件结构图,除了图5所示的处理器、内存、网络接口、以及非易失性存储器之外,实施例中装置所在的任意具备数据处理能力的设备通常根据该任意具备数据处理能力的设备的实际功能,还可以包括其他硬件,对此不再赘述。
上述装置中各个单元的功能和作用的实现过程具体详见上述方法中对应步骤的实现过程,在此不再赘述。
对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本发明方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
本发明实施例还提供一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述实施例中的基于生成对抗网络的图数据类别不平衡分类方法。
所述计算机可读存储介质可以是前述任一实施例所述的任意具备数据处理能力的设备的内部存储单元,例如硬盘或内存。所述计算机可读存储介质也可以是任意具备数据处理能力的设备的外部存储设备,例如所述设备上配备的插接式硬盘、智能存储卡(Smart Media Card,SMC)、SD卡、闪存卡(Flash Card)等。进一步的,所述计算机可读存储介质还可以既包括任意具备数据处理能力的设备的内部存储单元也包括外部存储设备。所述计算机可读存储介质用于存储所述计算机程序以及所述任意具备数据处理能力的设备所需的其他程序和数据,还可以用于暂时地存储已经输出或者将要输出的数据。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的范围。

Claims (10)

1.一种基于生成对抗网络的图数据类别不平衡分类方法,其特征在于包括如下步骤:
步骤S1:构建生成器;将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量,属性图信息包括图的节点和边信息,节点为论文节点,对应论文内容,带有论文属性信息及论文类别信息,边表示论文间引用关系;
步骤S2:对少样本图数据进行过采样;根据合成少数类过采样技术SMOTE原则,对学习到的每个少数类表示向量进行近邻计算,选择其最近邻节点进行插值计算,计算少数论文类别中该论文节点的最近邻论文节点,即同一论文类别中距离论文节点最近的论文邻节点,然后生成新的合成节点;
步骤S3:重建平衡图数据;通过已有图的论文节点和边信息训练边生成器,对生成的论文节点进行链路预测;
步骤S4:将重建的平衡图数据作为判别器的输入,进行节点分类,即论文类别的判别。
2.根据权利要求1所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤1包括如下步骤:
步骤S1.1:提取图的空间结构;
步骤S1.2:提取图的低阶信息;
步骤S1.3:提取图的高阶信息;
步骤S1.4:将学习到的高阶信息拼接到低阶信息中,得到最后的表示向量;
步骤S1.5:步骤S1.1至步骤S1.4构成的生成器,包括对生成的少数类数据的混淆鉴别器损失和对生成数据的条件约束;
步骤S1.6:通过生成器和判别器动态更新参数,优化学习到的嵌入表示向量和生成节点。
3.根据权利要求2所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.2中,首先初始化节点的表示
Figure FDA0004128593700000011
V表示节点集,然后聚合T跳邻居节点的表示/>
Figure FDA0004128593700000012
T表示与邻居节点的相邻层数,mean表示求取{·}数组平均值的函数,最后将v节点t-1层邻居节点u的信息拼接到t层节点v的向量表示上,得到节点v第t层的向量表示,N表示节点的总数。
4.根据权利要求2所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.1中,根据节点带有的属性信息,构造节点属性矩阵;所述步骤S1.3包括如下步骤:
步骤S1.3.1:首先使用图卷积网络GCN学习节点属性和拓扑信息,然后利用K最邻近算法构建原始超边,对每个节点进行近邻计算,并将其与近邻构成基础超边的集合;
步骤S1.3.2:通过K均值聚类算法K-means对节点嵌入表示进行聚类,学习到聚类中心,计算每个节点到聚类中心的距离,然后将聚类中心加入到基础超边的集合中;
步骤S1.3.3:对基础超图进行超图卷积,不断更新节点的表示,超图卷积,先通过多层感知机MLP学习到节点的转移状态矩阵,然后利用一维的双曲图卷积神经网络HGCN学习超边的向量表示,最后归一化超边的信息,将其聚合到节点上,获得节点最终的高阶表示。
5.根据权利要求4所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.3.3中,超图卷积公式如下:
T=MLP(xu)
h′e=HGCN(T·MLP(xu))
w=softmax(h′eW+b)
Figure FDA0004128593700000021
其中,xu表示节点v的邻居节点经过图卷积网络GCN学习到的嵌入表示,T表示通过多层感知机MLP学习到节点的转移状态矩阵,h′e表示通过双曲图卷积神经网络HGCN学习超边的向量表示,W和b分别表示softmax激活函数的权重和偏置,w表示归一化后的超边信息,|Adj(v)|表示节点v包含的所有超边的数量,hv表示节点最终的高阶表示。
6.根据权利要求2所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.1中,根据图的节点和边信息,构造图的邻接矩阵A;所述步骤S1.2中,使用解码器重构图数据,形成仅包含原始节点的重构邻接矩阵AD;所述步骤S1.4中,使用注意力机制提取两种表示中重要信息进行下一层传播,即最后的表示向量zv=cat[hv,xv],cat表示拼接操作,hv表示高阶信息,xv表示经步骤S1.2得到的节点v经过图卷积网络GCN学习到的嵌入表示;
所述步骤S1.5中,对生成的少数类数据的混淆鉴别器损失,包括Lrf用于判断节点是生成节点还是真实节点,通过损失训练使判别器将生成节点识别为真实节点,Lmaj用于控制生成的少数类节点尽可能远离多数类;对生成数据的条件约束,包括Ldis用于将生成的少数类节点靠近真实的少数类节点,Lrec用于控制编码器学习图的真实信息;
Figure FDA0004128593700000031
其中,
Figure FDA0004128593700000032
表示第i个节点在学习到表示向量zi时属于真实类的概率/>
Figure FDA0004128593700000033
zi表示第i个节点的最终向量表示,real表示节点属于真实类,majority表示节点属于多数类,
Figure FDA0004128593700000034
表示第i个节点的预测标签,qi表示第i个节点的真实标签,qj表示第j个节点的真实标签,p(zi)表示第i个节点属于少数类的概率,ng表示生成的节点集,nmin表示少数类节点集,/>
Figure FDA0004128593700000035
表示正则化。
7.根据权利要求1所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S2中,计算少数类l中节点v的最近邻节点,
Figure FDA0004128593700000036
s.t.lu=lv,其中/>
Figure FDA0004128593700000037
表示节点v属于类别l的表示向量,/>
Figure FDA0004128593700000038
表示节点u属于类别l的表示向量,nn(v)表示同一类别中距离节点v最近的邻节点,argmin||·||表示取距离最近操作;然后生成新的合成节点/>
Figure FDA0004128593700000039
Figure FDA00041285937000000310
δ表示平衡系数。
8.根据权利要求7所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S3中,使用点积操作进行边预测,节点u和v的边概率是
Figure FDA00041285937000000311
边生成器的损失函数loss为/>
Figure FDA00041285937000000312
其中W是线性函数softmax的权重矩阵,E表示图的边集,A表示根据图的节点和边信息构造的图的邻接矩阵,当预测概率大于阈值时,认为节点u和v存在边,通过不断地优化学习,最终得到重构图的边信息。
9.根据权利要求1所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S4中,利用谱图神经网络GCN学习节点的嵌入表示,并结合softmax函数进行多类别分类;判别器的损失loss函数,分别是真实节点和生成节点是否是多数类,以及多数类和少数类的交叉熵函数,具体如下:
Figure FDA0004128593700000041
其中,Lfa是用以区分节点是真实节点还是生成器生成的节点的交叉熵损失;Lcl是用以区分节点是少数类还是多数类的交叉熵损失,将节点数最多的一组类别作为多数类,其他类别为少数类,且保留原始的类别信息,减号表示希望数据尽可能远离多数类;Ldis是用以扩大不同的类节点之间的嵌入距离的损失函数;
Figure FDA0004128593700000042
表示第i个节点在学习到表示向量zi时属于伪类fake的概率/>
Figure FDA0004128593700000043
zi表示第i个节点的最终向量表示,minority表示节点属于少数类,/>
Figure FDA0004128593700000044
表示第i个节点的预测标签,qi表示第i个节点的真实标签,qj表示第j个节点的真实标签,p(zi)表示第i个节点属于少数类的概率,ng表示生成的节点集,nmin表示少数类节点集,nmaj表示多数类节点集,N表示节点的总数。
10.一种基于生成对抗网络的图数据类别不平衡分类装置,其特征在于,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现权利要求1-9中任一项所述的基于生成对抗网络的图数据类别不平衡分类方法。
CN202211461517.1A 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置 Active CN115661550B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211461517.1A CN115661550B (zh) 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211461517.1A CN115661550B (zh) 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置

Publications (2)

Publication Number Publication Date
CN115661550A CN115661550A (zh) 2023-01-31
CN115661550B true CN115661550B (zh) 2023-05-30

Family

ID=85018043

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211461517.1A Active CN115661550B (zh) 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置

Country Status (1)

Country Link
CN (1) CN115661550B (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116628538A (zh) * 2023-07-26 2023-08-22 之江实验室 基于图对齐神经网络的患者聚类方法、装置和计算机设备
CN116721441B (zh) * 2023-08-03 2024-01-19 厦门瞳景智能科技有限公司 基于区块链的门禁安全管理方法与系统
CN116936108B (zh) * 2023-09-19 2024-01-02 之江实验室 一种面向不平衡数据的疾病预测系统

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
GB201910720D0 (en) * 2019-07-26 2019-09-11 Tomtom Global Content Bv Generative adversarial Networks for image segmentation
WO2020163970A1 (en) * 2019-02-15 2020-08-20 Surgical Safety Technologies Inc. System and method for adverse event detection or severity estimation from surgical data
CN111597887A (zh) * 2020-04-08 2020-08-28 北京大学 一种行人再识别方法及系统
CN115130509A (zh) * 2022-06-29 2022-09-30 哈尔滨工业大学(威海) 基于条件式变分自编码器的心电信号生成方法

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020163970A1 (en) * 2019-02-15 2020-08-20 Surgical Safety Technologies Inc. System and method for adverse event detection or severity estimation from surgical data
GB201910720D0 (en) * 2019-07-26 2019-09-11 Tomtom Global Content Bv Generative adversarial Networks for image segmentation
CN111597887A (zh) * 2020-04-08 2020-08-28 北京大学 一种行人再识别方法及系统
CN115130509A (zh) * 2022-06-29 2022-09-30 哈尔滨工业大学(威海) 基于条件式变分自编码器的心电信号生成方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Karras T.等.Analyzing and improving the image quality of stylegan.《Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition》.2020,全文. *
俞彬.基于生成对抗网络的图像类别不平衡问题数据扩充方法.《知网》.2018,全文. *

Also Published As

Publication number Publication date
CN115661550A (zh) 2023-01-31

Similar Documents

Publication Publication Date Title
CN110263227B (zh) 基于图神经网络的团伙发现方法和系统
CN115661550B (zh) 基于生成对抗网络的图数据类别不平衡分类方法及装置
Wang et al. A deep convolutional neural network for topology optimization with perceptible generalization ability
Schulz et al. Deep learning: Layer-wise learning of feature hierarchies
Wang A hybrid sampling SVM approach to imbalanced data classification
KR102295805B1 (ko) 학습 데이터 관리 방법
Laha Building contextual classifiers by integrating fuzzy rule based classification technique and k-nn method for credit scoring
Joy et al. Batch Bayesian optimization using multi-scale search
CN112115998B (zh) 一种基于对抗增量聚类动态路由网络克服灾难性遗忘的方法
US20210264209A1 (en) Method for generating anomalous data
AghaeiRad et al. Improve credit scoring using transfer of learned knowledge from self-organizing map
US11816554B2 (en) Method and apparatus for generating weather data based on machine learning
Du et al. Polyline simplification based on the artificial neural network with constraints of generalization knowledge
Hong et al. Variational gridded graph convolution network for node classification
Qu et al. Effects of loss function and data sparsity on smooth manifold extraction with deep model
Dan et al. Pf-vit: Parallel and fast vision transformer for offline handwritten chinese character recognition
KR20220099409A (ko) 딥러닝 모델을 사용한 분류 방법
CN112541530A (zh) 针对聚类模型的数据预处理方法及装置
US20210256374A1 (en) Method and apparatus with neural network and training
Jin Handwritten digit recognition based on classical machine learning methods
Liang et al. Supervised and unsupervised learning models
Silaparasetty et al. Neural Network Collection
KR102579684B1 (ko) 신경망 학습모델을 이용한 디지털 휴먼 모델링 방법
US20230214674A1 (en) Method Of Training Object Prediction Models Using Ambiguous Labels
Yu et al. GRAHIES: Multi-scale graph representation learning with latent hierarchical structure

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant