CN115661550A - 基于生成对抗网络的图数据类别不平衡分类方法及装置 - Google Patents

基于生成对抗网络的图数据类别不平衡分类方法及装置 Download PDF

Info

Publication number
CN115661550A
CN115661550A CN202211461517.1A CN202211461517A CN115661550A CN 115661550 A CN115661550 A CN 115661550A CN 202211461517 A CN202211461517 A CN 202211461517A CN 115661550 A CN115661550 A CN 115661550A
Authority
CN
China
Prior art keywords
node
nodes
graph
class
information
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202211461517.1A
Other languages
English (en)
Other versions
CN115661550B (zh
Inventor
张阳
还章军
余婷
张吉
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202211461517.1A priority Critical patent/CN115661550B/zh
Publication of CN115661550A publication Critical patent/CN115661550A/zh
Application granted granted Critical
Publication of CN115661550B publication Critical patent/CN115661550B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了基于生成对抗网络的图数据类别不平衡分类方法及装置,通过构建生成器,将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量;再对少样本图数据进行过采样,根据合成少数类过采样技术SMOTE原则,对学习到的每个少数类表示向量进行近邻计算,选择其最近邻节点进行插值计算,生成新的节点;然后重建平衡图数据,通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测;最后将重建的平衡图数据作为判别器的输入,进行节点分类。有效解决了数据不平衡的假阳性问题,提高了图的节点分类准确率。

Description

基于生成对抗网络的图数据类别不平衡分类方法及装置
技术领域
本发明涉及图数据挖掘、数据不平衡技术领域,尤其是涉及基于生成对抗网络的图数据类别不平衡分类方法及装置。
背景技术
数据不平衡研究是一个经典的机器学习问题,广泛存在于工业生产、计算机视觉、信息安全等诸多领域,是近年来持续研究热点之一。数据不平衡指的是不同类别的数据样本之间的不平衡,目前研究通常针对文本、图片等数值型数据,主要分为数据层面的样本过采样和欠采样,算法层面的代价敏感函数设计以及集成学习三种方法。在过采样技术中,SMOTE(Synthetic Minority Oversampling Technique,合成少数类过采样技术)算法是经典算法之一,常用来解决不平衡问题。然而,直接使用SMOTE方法对每个原少数类样本进行人工样本会带来过拟合问题,而没有考虑近邻样本,从而增加了不同类别样本的重叠。有鉴于此,许多变形算法被提出来克服这种缺陷。一些代表性的工作包括:Boderline-SMOTE,Adaptive Synthetic Sampling,Safe-Level-SMOTE以及SPIDER2算法。
但是上述方法都是应用于数值型数据中。近年来,随着计算机的软硬件发展,图神经网络已经应用到图数据的各个领域并取得不错的成果。图数据的不平衡问题也逐渐被研究学者发现。不少学者发现数据中经常存在倾斜类分布的帕累托失衡现象,这导致模型对多数类的有偏学习,使得模型精准识别少数类数据非常困难。如未考虑数据本身的倾斜分布,模型对数据的学习会产生一个具有欺骗性的高精度度量,在假阴性比假阳性代价更高的情况下,模型倾向于多数类的预测偏差可能会产生不良后果,特别是在异常检测等领域。
目前,在图数据领域的不平衡方法还在初步研究中,2020年提出的DR-GCN方法,在图卷积网络的基础上加入了条件对抗正则化层和潜在分布对齐正则化层用于解决不均衡数据多分类问题。2021年提出的GraphSmote方法,它将SMOTE应用到图数据中,通过过采样解决图的不平衡问题。但是简单的卷积只学习到了节点的局部特征,过采样仍然会导致过拟合问题,无法解决类重叠问题。
发明内容
为解决现有技术的不足,实现图节点分类的准确度的目的,本发明采用如下的技术方案:
一种基于生成对抗网络的图数据类别不平衡分类方法,包括如下步骤:
步骤S1:构建生成器;将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量;
步骤S2:对少样本图数据进行过采样;根据合成少数类过采样技术SMOTE原则,对学习到的每个少数类表示向量进行近邻计算,选择其最近邻节点进行插值计算,生成新的节点,其中K的取值取决于该类与多数类的不平衡比;
步骤S3:重建平衡图数据;通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测;
步骤S4:将重建的平衡图数据作为判别器的输入,进行节点分类。
进一步地,所述步骤1包括如下步骤:
步骤S1.1:提取图的空间结构;
步骤S1.2:提取图的低阶信息;
步骤S1.3:提取图的高阶信息;
步骤S1.4:将学习到的高阶信息拼接到低阶信息中,得到最后的表示向量;
步骤S1.5:步骤S1.1至步骤S1.4构成的生成器,包括对生成的少数类数据的混淆鉴别器损失和对生成数据的条件约束;
步骤S1.6:通过生成器和判别器动态更新参数,优化学习到的嵌入表示向量z i 和生成节点n g
进一步地,所述步骤S1.2中,利用归纳式神经网络GraphSage学习节点的表示,具 体如下:首先初始化节点的表示
Figure 266790DEST_PATH_IMAGE001
,∀vVV表示节点集,然后聚合T跳邻居节点的表示
Figure 107707DEST_PATH_IMAGE002
t=1,2,3,…,T表示与邻居节点的相邻层数,mean表示 求取{·}数组平均值的函数,最后将v节点t-1层邻居节点u的信息拼接到t层节点v的向量 表示上,经过全连接层转换,得到节点vt层的向量表示。
进一步地,所述步骤S1.1中,根据节点带有的属性信息,构造节点属性矩阵;所述步骤S1.3包括如下步骤:
步骤S1.3.1:首先使用图卷积网络GCN学习节点属性和拓扑信息,然后利用K最邻近算法构建原始超边,对每个节点进行近邻计算,并将其与近邻构成基础超边的集合e b
例如节点v经过计算,其超边为
Figure 444011DEST_PATH_IMAGE003
,其中,x v 表示节点v经过GCN学 习到的嵌入表示,x u 表示节点v的邻居节点经过GCN学习到的嵌入表示。
步骤S1.3.2:通过K均值聚类算法K-means对X={x 1,x 2,…,x N }的节点嵌入表示进行 聚类,学习到S个聚类中心,计算每个节点到聚类中心的距离,然后将
Figure 828243DEST_PATH_IMAGE004
的聚类中心加入 到基础超边的集合e b 中,由于超边是由多个节点组成的边,加入的聚类中心增加了节点数 量,从而扩大了超边;
步骤S1.3.3:采用
Figure 199181DEST_PATH_IMAGE005
表示超边e包含的顶点集合,k e 表示超 边e包含的顶点个数,
Figure 945421DEST_PATH_IMAGE006
表示节点v包含的所有超边集合,k v 表示包含 节点v超边个数;对基础超图进行超图卷积,不断更新节点的表示,超图卷积如公式一所示, 先通过多层感知机MLP学习到节点的转移状态矩阵,然后利用一维的双曲图卷积神经网络 HGCN学习超边的向量表示,最后归一化超边的信息,将其聚合到节点v上,获得节点最终的 高阶表示h v
进一步地,所述步骤S1.3.3中,超图卷积公式如下:
Figure 769020DEST_PATH_IMAGE007
Figure 954014DEST_PATH_IMAGE008
Figure 913880DEST_PATH_IMAGE009
Figure 96599DEST_PATH_IMAGE010
其中,x u 表示节点v的邻居节点经过图卷积网络GCN学习到的嵌入表示,T表示通过 多层感知机MLP学习到节点的转移状态矩阵,
Figure 469812DEST_PATH_IMAGE011
表示通过双曲图卷积神经网络HGCN学习超 边的向量表示,Wb分别表示softmax激活函数的权重和偏置,w表示归一化后的超边信息, |Adj(v)|表示节点v包含的所有超边的数量,h v 表示节点最终的高阶表示。
进一步地,所述步骤S1.1中,根据图的节点和边信息,构造图的邻接矩阵A;所述步骤S1.2中,使用解码器重构图数据,形成仅包含原始节点的重构邻接矩阵AD;所述步骤S1.4中,使用注意力机制提取两种表示中重要信息进行下一层传播,即最后的表示向量z v =cat[h v ,x v ],cat表示拼接操作,h v 表示高阶信息,x v 表示经步骤S1.2得到的节点v经过图卷积网络GCN学习到的嵌入表示,将其输入步骤S2中,生成新节点;
所述步骤S1.5中,对生成的少数类数据的混淆鉴别器损失,包括L rf 用于判断节点是生成节点还是真实节点,通过损失训练使判别器将生成节点识别为真实节点,L maj 用于控制生成的少数类节点尽可能远离多数类;对生成数据的条件约束,包括L dis 用于将生成的少数类节点靠近真实的少数类节点,L rec 用于控制编码器学习图的真实信息。
Figure 130600DEST_PATH_IMAGE012
其中,
Figure 7289DEST_PATH_IMAGE013
表示第i个节点在学习到表示向量z i 时属于真实类的概率
Figure 626490DEST_PATH_IMAGE014
,z i 表示第i个节点的最终向量表示,real表示节点属于真实类,majority表示节点属于 多数类,
Figure 159102DEST_PATH_IMAGE015
表示第i个节点的预测标签,q i 表示第i个节点的真实标签,q j 表示第j个节点的真 实标签,p(z i )表示第i个节点属于少数类的概率,n g 表示生成的节点集,n min 表示少数类节点 集,
Figure 417389DEST_PATH_IMAGE016
表示正则化。
进一步地,所述步骤S2中,计算少数类l中节点v的最近邻节点,
Figure 351847DEST_PATH_IMAGE017
s.t. l u =l v ,其中
Figure 938686DEST_PATH_IMAGE018
表示节点v属于类别l的表示向量,
Figure 958595DEST_PATH_IMAGE019
表示 节点u属于类别l的表示向量,nn(v)表示同一类别中距离节点v最近的邻节点,argmin||·| |表示取距离最近操作;然后生成新的合成节点
Figure 23503DEST_PATH_IMAGE020
Figure 78047DEST_PATH_IMAGE021
,δ表示平 衡系数。
进一步地,所述步骤S3中,使用点积操作进行边预测,节点uv的边概率是
Figure 773470DEST_PATH_IMAGE022
,边生成器的损失函数loss为
Figure 77413DEST_PATH_IMAGE023
,其中W是 线性函数softmax的权重矩阵,E表示图的边集,A表示根据图的节点和边信息构造的图的邻 接矩阵,当预测概率大于阈值时,认为节点uv存在边,通过不断地优化学习,最终得到重 构图的边信息。
进一步地,所述步骤S4中,利用谱图神经网络GCN学习节点的嵌入表示,并结合softmax函数进行多类别分类;判别器的损失loss函数如公式三所示,分别是真实节点和生成节点是否是多数类,以及多数类和少数类的交叉熵函数,具体如下:
Figure 149274DEST_PATH_IMAGE024
其中,L fa 是用以区分节点是真实节点还是生成器生成的节点的交叉熵损失;L cl 是 用以区分节点是少数类还是多数类的交叉熵损失,将节点数最多的一组类别作为多数类, 其他类别为少数类,且保留原始的类别信息,减号表示希望数据尽可能远离多数类;L dis 是 用以扩大不同的类节点之间的嵌入距离的损失函数;
Figure 58324DEST_PATH_IMAGE025
表示第i个节点在学习 到表示向量z i 时属于伪类fake的概率
Figure 721387DEST_PATH_IMAGE026
,z i 表示第i个节点的最终向量表示,minority表示 节点属于少数类,
Figure 718817DEST_PATH_IMAGE015
表示第i个节点的预测标签,q i 表示第i个节点的真实标签,q j 表示第j个 节点的真实标签,p(z i )表示第i个节点属于少数类的概率,n g 表示生成的节点集,n min 表示少 数类节点集,n maj 表示多数类节点集。
基于生成对抗网络的图数据类别不平衡分类装置,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现所述的基于生成对抗网络的图数据类别不平衡分类方法。
本发明的优势和有益效果在于:
本发明的基于生成对抗网络的图数据类别不平衡分类方法及装置,通过高低阶构图,学习节点的局部和全局信息,并结合生成对抗思路,动态更新生成节点,有效解决了图数据不平衡问题,同时实验表明,本发明优于现有的SOTA方法。
附图说明
图1是本发明实施例中方法的流程图。
图2是本发明实施例中不平衡动态卷积生成对抗网络的原理图。
图3是本发明实施例中cora数据集的实验结果图。
图4a是本发明实施例中不考虑数据本身的不平衡问题时节点分类准确度示意图。
图4b是本发明实施例中考虑数据本身的不平衡问题时节点分类准确度示意图。
图5是本发明实施例中装置的结构示意图。
具体实施方式
以下结合附图对本发明的具体实施方式进行详细说明。应当理解的是,此处所描述的具体实施方式仅用于说明和解释本发明,并不用于限制本发明。
如图1、图2所示,基于生成对抗网络的图数据类别不平衡分类方法,包含如下步骤:
步骤S1:构建生成器。将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量Z,包括如下步骤:
步骤S1.1:提取图的空间结构:原始图G=(V,E),V表示节点集,E表示边集。根据图的节点和边信息,构造图的邻接矩阵AR N*N A用于表示图的拓扑结构特征;根据节点带有的属性信息,构造节点属性矩阵FR N*M ,其中,N表示节点的总数,M表示节点属性空间的总维度。
步骤S1.2:提取图的低阶信息:利用归纳式神经网络GraphSage学习节点的表示, 具体如下:首先初始化节点的表示
Figure 125528DEST_PATH_IMAGE001
,∀vV,然后聚合T跳邻居节点的表示
Figure 889084DEST_PATH_IMAGE002
t=1,2,3,…,T表示与邻居节点的相邻层数,mean表示 求取{·}数组平均值的函数,最后将v节点t-1层邻居节点u的信息拼接到t层节点v的向量 表示上,经过全连接层转换,得到节点vt层的向量表示,使用解码器重构图数据,形成仅 包含原始节点的重构邻接矩阵AD,具体地,图2中重构邻接矩阵AD,通过对节点的表示h及其 转置后的hT进行点积,再通过sigmoid激活函数得到。
步骤S1.3:提取图的高阶信息:该方法提出一种动态超图构建方法,通过不断迭代优化学习节点的高阶信息,并将其与低阶信息融合,以便于步骤S2生成优质的节点。具体如下:
步骤S1.3.1:首先使用GCN(Graph Convolutional Network,图卷积网络)学习节 点属性和拓扑信息,然后利用KNN(K-NearestNeighbor,K最邻近)算法构建原始超边,对每 个节点进行近邻计算,并将其与近邻构成基础超边的集合e b 。例如节点v经过计算,其超边 为
Figure 457469DEST_PATH_IMAGE003
,其中,x v 表示节点v经过GCN学习到的嵌入表示,x u 表示节点v的邻居 节点经过GCN学习到的嵌入表示。
步骤S1.3.2:通过K-means对X={x 1,x 2,…,x N }的节点嵌入表示进行聚类,学习到S 个聚类中心,计算每个节点到聚类中心的距离,然后将
Figure 1583DEST_PATH_IMAGE004
的聚类中心加入到e b 中,由于超 边是由多个节点组成的边,加入的聚类中心增加了节点数量,从而扩大了超边。
步骤S1.3.3:采用
Figure 680826DEST_PATH_IMAGE005
表示超边e包含的顶点集合,k e 表示超 边e包含的顶点个数,
Figure 361206DEST_PATH_IMAGE006
表示节点v包含的所有超边集合,k v 表示包含 节点v超边个数;对基础超图进行超图卷积,不断更新节点的表示,超图卷积如公式一所示, 先通过MLP(Multilayer Perceptron,多层感知机)学习到节点的转移状态矩阵,然后利用 一维的HGCN学习超边的向量表示,最后归一化超边的信息,将其聚合到节点v上,获得节点 最终的高阶表示h v
Figure 569333DEST_PATH_IMAGE007
Figure 335164DEST_PATH_IMAGE027
Figure 21360DEST_PATH_IMAGE028
(公式一)
Figure 373492DEST_PATH_IMAGE010
其中,T表示通过MLP学习到节点的转移状态矩阵,
Figure 486942DEST_PATH_IMAGE011
表示通过HGCN(Hyperbolic Graph Convolutional Neural Network,双曲图卷积神经网络)学习超边的向量表示,Wb 分别表示softmax激活函数的权重和偏置,w表示归一化后的超边信息,|Adj(v)|表示节点v 包含的所有超边的数量。
步骤S1.4:将学习到的高阶信息拼接到低阶信息中,并使用注意力机制提取两种表示中重要信息进行下一层传播,即最后的表示向量z v =cat[h v ,x v ],将其输入步骤S2中,生成新节点。
步骤S1.5:步骤S1.1至步骤S1.4构成的生成器,生成器的loss函数如公式二所示,同样由四部分组成,前两项是对生成的少数类数据的混淆鉴别器损失,L rf 用于判断节点是生成节点还是真实节点,通过损失训练使判别器将生成节点识别为真实节点,L maj 用于控制生成的少数类节点尽可能远离多数类。后两项是对生成数据的条件约束,L dis 的目的是将生成的少数类节点靠近真实的少数类节点,L rec 的目的是控制编码器学习图的真实信息;
Figure 740068DEST_PATH_IMAGE029
(公式二)
其中,
Figure 26693DEST_PATH_IMAGE030
表示第i个节点在学习到表示向量z i 时属于真实类的概率
Figure 619349DEST_PATH_IMAGE014
, z i 表示第i个节点的最终向量表示,real表示节点属于真实类,majority表示节点属于多数 类,
Figure 231596DEST_PATH_IMAGE015
表示第i个节点的预测标签,q i 表示第i个节点的真实标签,q j 表示第j个节点的真实标 签,p(z i )表示第i个节点属于少数类的概率,n g 表示生成的节点集,n min 表示少数类节点集,
Figure 706439DEST_PATH_IMAGE016
表示正则化,具体地,图2中L rec 即为
Figure 17DEST_PATH_IMAGE031
步骤S1.6:通过生成器和判别器动态更新模型参数,优化学习到的嵌入向量z i 和生成节点n g
步骤S2:对少样本图数据进行过采样。根据SMOTE原则,对学习到的每个少数类表示向量Z l 进行K近邻计算,选择其最近邻节点进行插值计算,生成新的节点N g ,其中K的取值取决于该类与多数类的不平衡比。
例如,计算少数类l中节点v的最近邻节点,
Figure 243917DEST_PATH_IMAGE017
s.t. l u =l v ,其中
Figure 967678DEST_PATH_IMAGE018
表示节点v属于类别l的表示向量,
Figure 929818DEST_PATH_IMAGE019
表示节点u属于类别l的表示向量,nn(v)表 示同一类别中距离节点v最近的邻节点,argmin||·||表示取距离最近操作;然后生成新的 合成节点
Figure 823824DEST_PATH_IMAGE020
Figure 125493DEST_PATH_IMAGE021
,δ表示平衡系数。
步骤S3:重建平衡图数据G’。通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测。
具体地,该方法使用点积操作进行边预测。节点uv的边概率是
Figure 79542DEST_PATH_IMAGE022
,边生成器的损失函数loss为
Figure 732240DEST_PATH_IMAGE023
,其中W是线性 函数的权重矩阵。当预测概率大于阈值0.5时,我们则认为节点uv存在边。通过不断地优 化学习,最终得到重构图的边信息。
步骤S4:将重建的平衡图数据G’作为判别器的输入,进行节点分类。在这里我们利用谱图神经网络GCN学习节点的嵌入表示,并结合softmax函数进行多类别分类。判别器的损失loss函数如公式三所示,分别是真实节点和生成节点是否是多数类,以及多数类和少数类的交叉熵函数。其中,第二项L cl 是因为该方法将节点数top1的类别当做多数类,其他均为少数类,且保留原始的类别信息。这里使用减号,是希望数据尽可能远离多数类。
Figure 164359DEST_PATH_IMAGE024
(公式三)
其中,L fa 是用以区分节点是真实节点还是生成器生成的节点的交叉熵损失;L cl 是 用以区分节点是少数类还是多数类的交叉熵损失,将节点数最多的一组类别作为多数类, 其他类别为少数类,且保留原始的类别信息,减号表示希望数据尽可能远离多数类;L dis 是 用以扩大不同的类节点之间的嵌入距离的损失函数;
Figure 586113DEST_PATH_IMAGE025
表示第i个节点在学习 到表示向量z i 时属于伪类fake的概率
Figure 711064DEST_PATH_IMAGE014
,z i 表示第i个节点的最终向量表示,minority表示 节点属于少数类,
Figure 585479DEST_PATH_IMAGE015
表示第i个节点的预测标签,q i 表示第i个节点的真实标签,q j 表示第j个 节点的真实标签,p(z i )表示第i个节点属于少数类的概率,n g 表示生成的节点集,n min 表示少 数类节点集,n maj 表示多数类节点集。
本发明实施例中,以cora数据集作为输入的图数据,来举例说明:
步骤S101,输入图数据。该数据集包含节点总数2708,节点特征总维度1433。节点的邻接矩阵是一个2708*2708维的矩阵,存储了每个节点的邻节点信息。节点属性矩阵F是一个2708*1433维矩阵,存储了每个节点的属性信息。根据节点对应的论文内容可将其分为7类,其中神经网络占比30.21%,遗传算法占比15.44%,概率学方法15.73%,理论12.96%,基于案例11%,增强学习8.01%,规则学习占比6.65%。
步骤S102,使用图卷积神经网络GraphSage学习节点和属性的融合信息,得到嵌入向量h
步骤S103,使用KNN和K-means初始化超图,然后使用超图卷积学习节点的高阶信息,并不断迭代更新超图,获得嵌入向量x
步骤S104,融合低阶表示h和高阶的嵌入表示x,得到节点的整体表示Z,根据SMOTE过采样规则,对少数类节点进行过采样,这里除了神经网络类别是多数类之外,其他6类均为少数类需要进行节点生成。例如初始类别训练集为[237,164,288,561,291,228,126],则我们需对其他节点进行补齐,得到最终的训练数据集为[561,561,561,561,561,561,561]。同时我们通过判别器反馈以及生成器损失loss进行更新。
步骤S105,将补齐之后的数据输入到判别器中,对模型进行训练,最终的分类结果如图3所示,与其他方法对比结果如表1所示。
表1对比实验结果表
Figure 290130DEST_PATH_IMAGE032
表中,基于Recall(召回率)评价指标、F1(F1值)评价指标、Auc(Area UnderCurve,曲线下面积)评价指标、AccAccuracy,准确率评价指标、Pre(精确率)评价指标,分别对GCN(图卷积神经网络)方法、SMOTE(合成少数类过采样技术)方法、GraphSMOTE(基于GNN的合成少数类过采样技术)方法和本发明的imGANSmote(基于生成对抗网络的图数据类别不平衡分类方法)方法进行评价的结果,根据实验结果可以看出,本发明的方法在图像分类精度上由于其他方法。
步骤S106,为了进一步确定该方法的有效性,我们进行了数据不平衡的消融实验。如图4a所示,imGANSmote当不考虑数据本身的不平衡问题时,节点分类准确度达到0.73(对角的准确度均值),但仔细分析结果会发现类别4、5、6的预测结果并不理想。如图4b所示,当考虑数据不平衡问题时,节点分类的正确性达到0.805,且每一类准确度都达到了0.69以上。
与前述基于生成对抗网络的图数据类别不平衡分类方法的实施例相对应,本发明还提供了基于生成对抗网络的图数据类别不平衡分类装置的实施例。
参见图5,本发明实施例提供的基于生成对抗网络的图数据类别不平衡分类装置,包括存储器和一个或多个处理器,存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现上述实施例中的基于生成对抗网络的图数据类别不平衡分类方法。
本发明基于生成对抗网络的图数据类别不平衡分类装置的实施例可以应用在任意具备数据处理能力的设备上,该任意具备数据处理能力的设备可以为诸如计算机等设备或装置。装置实施例可以通过软件实现,也可以通过硬件或者软硬件结合的方式实现。以软件实现为例,作为一个逻辑意义上的装置,是通过其所在任意具备数据处理能力的设备的处理器将非易失性存储器中对应的计算机程序指令读取到内存中运行形成的。从硬件层面而言,如图5所示,为本发明基于生成对抗网络的图数据类别不平衡分类装置所在任意具备数据处理能力的设备的一种硬件结构图,除了图5所示的处理器、内存、网络接口、以及非易失性存储器之外,实施例中装置所在的任意具备数据处理能力的设备通常根据该任意具备数据处理能力的设备的实际功能,还可以包括其他硬件,对此不再赘述。
上述装置中各个单元的功能和作用的实现过程具体详见上述方法中对应步骤的实现过程,在此不再赘述。
对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本发明方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
本发明实施例还提供一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述实施例中的基于生成对抗网络的图数据类别不平衡分类方法。
所述计算机可读存储介质可以是前述任一实施例所述的任意具备数据处理能力的设备的内部存储单元,例如硬盘或内存。所述计算机可读存储介质也可以是任意具备数据处理能力的设备的外部存储设备,例如所述设备上配备的插接式硬盘、智能存储卡(Smart Media Card,SMC)、SD卡、闪存卡(Flash Card)等。进一步的,所述计算机可读存储介质还可以既包括任意具备数据处理能力的设备的内部存储单元也包括外部存储设备。所述计算机可读存储介质用于存储所述计算机程序以及所述任意具备数据处理能力的设备所需的其他程序和数据,还可以用于暂时地存储已经输出或者将要输出的数据。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的范围。

Claims (10)

1.一种基于生成对抗网络的图数据类别不平衡分类方法,其特征在于包括如下步骤:
步骤S1:构建生成器;将属性图信息输入到生成器中,生成器包括低阶神经网络和高阶神经网络,学习图的局部和全局信息,学习并得到节点的嵌入表示向量;
步骤S2:对少样本图数据进行过采样;根据合成少数类过采样技术SMOTE原则,对学习到的每个少数类表示向量进行近邻计算,选择其最近邻节点进行插值计算,生成新的节点;
步骤S3:重建平衡图数据;通过已有图的节点和边信息训练边生成器,对生成的节点进行链路预测;
步骤S4:将重建的平衡图数据作为判别器的输入,进行节点分类。
2.根据权利要求1所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤1包括如下步骤:
步骤S1.1:提取图的空间结构;
步骤S1.2:提取图的低阶信息;
步骤S1.3:提取图的高阶信息;
步骤S1.4:将学习到的高阶信息拼接到低阶信息中,得到最后的表示向量;
步骤S1.5:步骤S1.1至步骤S1.4构成的生成器,包括对生成的少数类数据的混淆鉴别器损失和对生成数据的条件约束;
步骤S1.6:通过生成器和判别器动态更新参数,优化学习到的嵌入表示向量和生成节点。
3.根据权利要求2所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在 于:所述步骤S1.2中,首先初始化节点的表示
Figure 596508DEST_PATH_IMAGE001
,∀vVV表示节点集,然后聚合T跳邻居节 点的表示
Figure 451332DEST_PATH_IMAGE002
t=1,2,3,…,T表示与邻居节点的相邻层数, mean表示求取{·}数组平均值的函数,最后将v节点t-1层邻居节点u的信息拼接到t层节点v的向量表示上,得到节点vt层的向量表示。
4.根据权利要求2所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.1中,根据节点带有的属性信息,构造节点属性矩阵;所述步骤S1.3包括如下步骤:
步骤S1.3.1:首先使用图卷积网络GCN学习节点属性和拓扑信息,然后利用K最邻近算法构建原始超边,对每个节点进行近邻计算,并将其与近邻构成基础超边的集合;
步骤S1.3.2:通过K均值聚类算法K-means对节点嵌入表示进行聚类,学习到聚类中心,计算每个节点到聚类中心的距离,然后将聚类中心加入到基础超边的集合中;
步骤S1.3.3:对基础超图进行超图卷积,不断更新节点的表示,超图卷积,先通过多层感知机MLP学习到节点的转移状态矩阵,然后利用一维的双曲图卷积神经网络HGCN学习超边的向量表示,最后归一化超边的信息,将其聚合到节点上,获得节点最终的高阶表示。
5.根据权利要求4所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.3.3中,超图卷积公式如下:
Figure 434331DEST_PATH_IMAGE003
Figure 501644DEST_PATH_IMAGE004
Figure 250770DEST_PATH_IMAGE005
Figure 10916DEST_PATH_IMAGE006
其中,x u 表示节点v的邻居节点经过图卷积网络GCN学习到的嵌入表示,T表示通过多层 感知机MLP学习到节点的转移状态矩阵,
Figure 481211DEST_PATH_IMAGE007
表示通过双曲图卷积神经网络HGCN学习超边的 向量表示,Wb分别表示softmax激活函数的权重和偏置,w表示归一化后的超边信息,|Adj (v)|表示节点v包含的所有超边的数量,h v 表示节点最终的高阶表示。
6.根据权利要求2所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S1.1中,根据图的节点和边信息,构造图的邻接矩阵A;所述步骤S1.2中,使用解码器重构图数据,形成仅包含原始节点的重构邻接矩阵AD;所述步骤S1.4中,使用注意力机制提取两种表示中重要信息进行下一层传播,即最后的表示向量z v =cat[h v ,x v ],cat表示拼接操作,h v 表示高阶信息,x v 表示经步骤S1.2得到的节点v经过图卷积网络GCN学习到的嵌入表示;
所述步骤S1.5中,对生成的少数类数据的混淆鉴别器损失,包括L rf 用于判断节点是生成节点还是真实节点,通过损失训练使判别器将生成节点识别为真实节点,L maj 用于控制生成的少数类节点尽可能远离多数类;对生成数据的条件约束,包括L dis 用于将生成的少数类节点靠近真实的少数类节点,L rec 用于控制编码器学习图的真实信息;
Figure 148953DEST_PATH_IMAGE008
其中,
Figure 21094DEST_PATH_IMAGE009
表示第i个节点在学习到表示向量z i 时属于真实类的概率
Figure 217720DEST_PATH_IMAGE010
,z i 表示第i个节点的最终向量表示,real表示节点属于真实类,majority表示节点属于多数 类,
Figure 175312DEST_PATH_IMAGE011
表示第i个节点的预测标签,q i 表示第i个节点的真实标签,q j 表示第j个节点的真实标 签,p(z i )表示第i个节点属于少数类的概率,n g 表示生成的节点集,n min 表示少数类节点集,
Figure 115586DEST_PATH_IMAGE012
表示正则化。
7.根据权利要求1所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在 于:所述步骤S2中,计算少数类l中节点v的最近邻节点,
Figure 107813DEST_PATH_IMAGE013
s.t. l u =l v ,其中
Figure 740919DEST_PATH_IMAGE014
表示节点v属于类别l的表示向量,
Figure 185807DEST_PATH_IMAGE015
表示节点u属于类别l的表示向量,nn(v) 表示同一类别中距离节点v最近的邻节点,argmin||·||表示取距离最近操作;然后生成新 的合成节点
Figure 926843DEST_PATH_IMAGE016
Figure 773576DEST_PATH_IMAGE017
,δ表示平衡系数。
8.根据权利要求7所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在 于:所述步骤S3中,使用点积操作进行边预测,节点uv的边概率是
Figure 312005DEST_PATH_IMAGE018
,边生成器的损失函数loss为
Figure 509768DEST_PATH_IMAGE019
,其中W是 线性函数softmax的权重矩阵,E表示图的边集,A表示根据图的节点和边信息构造的图的邻 接矩阵,当预测概率大于阈值时,认为节点uv存在边,通过不断地优化学习,最终得到重 构图的边信息。
9.根据权利要求1所述的基于生成对抗网络的图数据类别不平衡分类方法,其特征在于:所述步骤S4中,利用谱图神经网络GCN学习节点的嵌入表示,并结合softmax函数进行多类别分类;判别器的损失loss函数,分别是真实节点和生成节点是否是多数类,以及多数类和少数类的交叉熵函数,具体如下:
Figure 791845DEST_PATH_IMAGE021
其中,L fa 是用以区分节点是真实节点还是生成器生成的节点的交叉熵损失;L cl 是用以 区分节点是少数类还是多数类的交叉熵损失,将节点数最多的一组类别作为多数类,其他 类别为少数类,且保留原始的类别信息,减号表示希望数据尽可能远离多数类;L dis 是用以 扩大不同的类节点之间的嵌入距离的损失函数;
Figure 227505DEST_PATH_IMAGE022
表示第i个节点在学习到表 示向量z i 时属于伪类fake的概率
Figure 936835DEST_PATH_IMAGE023
,z i 表示第i个节点的最终向量表示,minority表示节点 属于少数类,
Figure 825157DEST_PATH_IMAGE011
表示第i个节点的预测标签,q i 表示第i个节点的真实标签,q j 表示第j个节点 的真实标签,p(z i )表示第i个节点属于少数类的概率,n g 表示生成的节点集,n min 表示少数类 节点集,n maj 表示多数类节点集。
10.一种基于生成对抗网络的图数据类别不平衡分类装置,其特征在于,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现权利要求1-9中任一项所述的基于生成对抗网络的图数据类别不平衡分类方法。
CN202211461517.1A 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置 Active CN115661550B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211461517.1A CN115661550B (zh) 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211461517.1A CN115661550B (zh) 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置

Publications (2)

Publication Number Publication Date
CN115661550A true CN115661550A (zh) 2023-01-31
CN115661550B CN115661550B (zh) 2023-05-30

Family

ID=85018043

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211461517.1A Active CN115661550B (zh) 2022-11-17 2022-11-17 基于生成对抗网络的图数据类别不平衡分类方法及装置

Country Status (1)

Country Link
CN (1) CN115661550B (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116628538A (zh) * 2023-07-26 2023-08-22 之江实验室 基于图对齐神经网络的患者聚类方法、装置和计算机设备
CN116721441A (zh) * 2023-08-03 2023-09-08 厦门瞳景智能科技有限公司 基于区块链的门禁安全管理方法与系统
CN116936108A (zh) * 2023-09-19 2023-10-24 之江实验室 一种面向不平衡数据的疾病预测系统
CN117910519A (zh) * 2024-03-20 2024-04-19 烟台大学 进化图生成对抗网络的图应用方法、系统及推荐方法
CN117910519B (zh) * 2024-03-20 2024-06-07 烟台大学 进化图生成对抗网络的推荐方法

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
GB201910720D0 (en) * 2019-07-26 2019-09-11 Tomtom Global Content Bv Generative adversarial Networks for image segmentation
WO2020163970A1 (en) * 2019-02-15 2020-08-20 Surgical Safety Technologies Inc. System and method for adverse event detection or severity estimation from surgical data
CN111597887A (zh) * 2020-04-08 2020-08-28 北京大学 一种行人再识别方法及系统
CN115130509A (zh) * 2022-06-29 2022-09-30 哈尔滨工业大学(威海) 基于条件式变分自编码器的心电信号生成方法

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020163970A1 (en) * 2019-02-15 2020-08-20 Surgical Safety Technologies Inc. System and method for adverse event detection or severity estimation from surgical data
GB201910720D0 (en) * 2019-07-26 2019-09-11 Tomtom Global Content Bv Generative adversarial Networks for image segmentation
CN111597887A (zh) * 2020-04-08 2020-08-28 北京大学 一种行人再识别方法及系统
CN115130509A (zh) * 2022-06-29 2022-09-30 哈尔滨工业大学(威海) 基于条件式变分自编码器的心电信号生成方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
KARRAS T.等: "Analyzing and improving the image quality of stylegan" *
俞彬: "基于生成对抗网络的图像类别不平衡问题数据扩充方法" *

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116628538A (zh) * 2023-07-26 2023-08-22 之江实验室 基于图对齐神经网络的患者聚类方法、装置和计算机设备
CN116721441A (zh) * 2023-08-03 2023-09-08 厦门瞳景智能科技有限公司 基于区块链的门禁安全管理方法与系统
CN116721441B (zh) * 2023-08-03 2024-01-19 厦门瞳景智能科技有限公司 基于区块链的门禁安全管理方法与系统
CN116936108A (zh) * 2023-09-19 2023-10-24 之江实验室 一种面向不平衡数据的疾病预测系统
CN116936108B (zh) * 2023-09-19 2024-01-02 之江实验室 一种面向不平衡数据的疾病预测系统
CN117910519A (zh) * 2024-03-20 2024-04-19 烟台大学 进化图生成对抗网络的图应用方法、系统及推荐方法
CN117910519B (zh) * 2024-03-20 2024-06-07 烟台大学 进化图生成对抗网络的推荐方法

Also Published As

Publication number Publication date
CN115661550B (zh) 2023-05-30

Similar Documents

Publication Publication Date Title
Xin et al. Neurolkh: Combining deep learning model with lin-kernighan-helsgaun heuristic for solving the traveling salesman problem
He et al. AutoML: A survey of the state-of-the-art
CN110263227B (zh) 基于图神经网络的团伙发现方法和系统
Alzubaidi et al. A survey on deep learning tools dealing with data scarcity: definitions, challenges, solutions, tips, and applications
CN115661550A (zh) 基于生成对抗网络的图数据类别不平衡分类方法及装置
KR102295805B1 (ko) 학습 데이터 관리 방법
CN112115998B (zh) 一种基于对抗增量聚类动态路由网络克服灾难性遗忘的方法
Joy et al. Batch Bayesian optimization using multi-scale search
CN112990280A (zh) 面向图像大数据的类增量分类方法、系统、装置及介质
KR102285530B1 (ko) 영상 정합을 위한 이미지 처리 방법
Du et al. Polyline simplification based on the artificial neural network with constraints of generalization knowledge
KR20220000387A (ko) 기계 학습 기반 기상 자료 생성 장치 및 방법
Wankhade et al. Data stream classification: a review
Li et al. Automatic design of machine learning via evolutionary computation: A survey
Hong et al. Variational gridded graph convolution network for node classification
Qu et al. Effects of loss function and data sparsity on smooth manifold extraction with deep model
CN116524282B (zh) 一种基于特征向量的离散相似度匹配分类方法
Gao et al. Multi-objective pointer network for combinatorial optimization
CN113297385B (zh) 基于改进GraphRNN的多标签文本分类系统及分类方法
KR102437396B1 (ko) 모델 학습 방법
Guo et al. End-to-end variational graph clustering with local structural preservation
Huang et al. Building hierarchical class structures for extreme multi-class learning
CN115331754A (zh) 基于哈希算法的分子分类方法
Jiao et al. Scalable self-supervised graph representation learning via enhancing and contrasting subgraphs
KR20210050413A (ko) 비정상 데이터 생성 방법

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant