CN114118416A - 一种基于多任务学习的变分图自动编码器方法 - Google Patents
一种基于多任务学习的变分图自动编码器方法 Download PDFInfo
- Publication number
- CN114118416A CN114118416A CN202111502928.6A CN202111502928A CN114118416A CN 114118416 A CN114118416 A CN 114118416A CN 202111502928 A CN202111502928 A CN 202111502928A CN 114118416 A CN114118416 A CN 114118416A
- Authority
- CN
- China
- Prior art keywords
- node
- graph
- matrix
- shallow
- inputting
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于多任务学习的变分图自动编码器方法,包括如下步骤:S1:对源数据进行预处理;S2:划分图数据集;S3:将S22获得的训练集输入浅层图卷积层获得浅层共享嵌入表示H;S4:将S3获得的浅层共享嵌入表示H分别输入两个不同的下游网络框架,获得各自的嵌入表示;S5:将S4获得的两个不同的嵌入表示分别进行链路预测任务和半监督节点分类任务。这种方法能使嵌入表示跟样本空间的真实分布更相近,在链路预测任务上具有很强的竞争力,鲁棒性强。
Description
技术领域
本发明涉及计算机数据分析领域,具体是一种基于多任务学习的变分图自动编码器方法。
背景技术
随着深度学习技术的不断发展,越来越多复杂的应用场景不能使用简单的欧几里得数据进行表示,例如分子结构、推荐系统、引文网络和社交网络等。这些应用数据,即非欧几里得数据,可以使用图来表示。图数据中包括节点和边,节点具有自己的属性特征,且不同的节点具有不同数量的邻居节点。传统的卷积神经网络或者循环神经网络不能够用于表示图数据。近年来,图神经网络吸引了研究者们极大的注意力,相对于卷积神经网络和循环神经网络,图神经网络可以通过保留拓扑结构信息和节点特征信息,将节点特征嵌入到低维空间,具有很强的性能。其中,图自动编码器和变分图自动编码器是进行图无监督学习(链路预测、节点聚类、图生成)的有效框架。
然而,图数据的多任务学习并没有引起研究者们太多的关注。实际上,将多个相关任务放在一起学习可以提高任务的整体泛化能力。目前已有的基于多任务学习的图神经网络框架都是直接将学习到的共享表示作为下游任务的输入,这意味着不同的下游任务使用共同的嵌入表示来进行学习,并没有着重去学习单个任务特定的嵌入信息。实际上,不同任务使用共同的嵌入表示可能不利于各自任务的学习,因为这样学习到的共享嵌入表示可能也学习到了其他任务的噪声。
发明内容
本发明的目的是针对现有技术的不足,而提供一种基于多任务学习的变分图自动编码器方法。这种方法能使嵌入表示跟样本空间的真实分布更相近,在链路预测任务上具有很强的竞争力,鲁棒性强。
实现本发明目的的技术方案是:
一种基于多任务学习的变分图自动编码器方法,包括如下步骤:
S1:对源数据进行预处理,具体过程为:
S11:将引文网络中的源数据处理成图数据G=(V,E),V为节点集,E为边集,假设引文网络中的一篇论文视为图中的一个节点,论文的作者、研究方向视为节点的特征,论文与被引用论文之间建立一条连接的无向边,论文所属的类别视为标签,由此一个引文网络构成一个图数据集;
S12:利用S11中获得的图数据集,得到图对应的度矩阵和邻接矩阵、特征矩阵;
S2:划分图数据集,具体过程为:
S21:将图数据集中部分数据进行掩码设置来进行半监督学习;
S22:将S12获得的矩阵中的数据划分为训练集、验证集和测试集;
S3:将S22获得的训练集输入浅层图卷积层获得浅层共享嵌入表示H,即将邻接矩阵和特征矩阵输入浅层图卷积层,通过消息传播机制H=σ(AXW),其中A是邻接矩阵,X是节点特征矩阵,W是可学习的参数矩阵,σ是激活函数,以聚合当前邻居节点的特征信息和拓扑结构信息来更新当前节点的特征信息,从而获得浅层共享嵌入表示H;
S4:将S3获得的浅层共享嵌入表示H分别输入两个不同的下游网络框架,获得各自的嵌入表示,具体过程为:
S41:将S3获得的浅层共享嵌入表示H输入用于链路预测的图卷积网络中,获得嵌入表示Z_mean和Z_log;
S42:将Z_mean和Z_log利用高斯分布进行相加,获得符合高斯分布的嵌入表示Z;
S43:将Z作为判别器的假样本输入,基于生成对抗机制,使得嵌入表示Z能够更接近原始的样本分布;
S44:将S3获得的浅层表示H输入用于节点分类的图卷积网络中,获得嵌入表示Z_nc;
S5:将S4获得的两个不同的嵌入表示分别进行链路预测任务和半监督节点分类任务,具体过程为:
S51:将S4获得的嵌入表示Z输入内积层进行邻接矩阵重构,用于链路预测任务;
S52:将S4获得的嵌入表示Z输入图卷积层进行特征矩阵重构,作为链路预测的辅助任务;
S53:将S4获得嵌入表示Z_nc输入用于节点分类的图卷积网络中;
S54:计算损失函数,利用梯度下降算法来更新迭代参数,经过多次迭代之后使得损失函数可以收敛,其中,最终的损失函数公式为:
其中,C是一组节点标签,如果节点i属于c类,y是节点所属的类别标签,是节点i属于类c的softmax概率,当节点i在有标签时MASKi=1,否则MASKi=0,Eq(Z|X,A)[logp(A|Z)]-KL[q(Z|X,A)||p(Z)]为邻接矩阵的重构损失,其中KL[q(·)||p(·)]是生成样本与原始样本的相对熵,为半监督节点分类的交叉熵损失,为特征矩阵的重构损失。
本技术方案的有益效果是:
本技术方案基于多任务联合学习无监督链路预测任务和半监督节点分类任务,不同于其他的基于多任务的图神经网络框架,直接使用共享表示作为不同预测或分类任务的输入,本技术方案仅在浅层获得共享表示,将共享表示分别输入到不同下游任务设计的专属网络框架,此外,为了使链路预测任务的嵌入表示更具鲁棒性,本技术方案加入了对抗生成网络框架,通过生成器-判别器的博弈机制使嵌入表示跟样本空间的真实分布更相近,在三个真实的引文网络数据集上的实验结果表明,本技术方案提出的框架在链路预测任务上具有很强的竞争力,甚至在其中一个数据集上优于最先进的方法。
这种方法能使嵌入表示跟样本空间的真实分布更相近,在链路预测任务上具有很强的竞争力,鲁棒性强。
附图说明
图1为实施例的流程示意图。
具体实施方式
下面结合附图及具体实施例对本发明作进一步的详细描述,但不是对本发明的限定。
实施例:
本例适用于非欧式空间的数据,例如:社交网络、引文网络和分子结构。
参照图1,一种基于多任务学习的变分图自动编码器方法,包括如下步骤:
S1:对源数据进行预处理,具体过程为:
S11:本例在引文网络中收集图数据集,每一个图数据集的类别数量不同,数据集中的每一篇论文都有自己的标签,将引文网络中的图数据集处理成图数据G=(V,E),V为节点集,E为边集,假设一篇论文视为图中的一个节点,论文的作者、研究方向视为节点的特征,论文与被引用论文之间建立一条连接的无向边,论文所属的类别视为标签,一个引文网络构成一个图数据集,三个引文网络图数据集详细情况如表1所示:
表1数据集
数据集 | 节点数量 | 边数量 | 特征维度 | 类别数 |
Cora | 2708 | 5429 | 1433 | 7 |
Citeseer | 3327 | 4732 | 3703 | 6 |
Pubmed | 19717 | 44338 | 500 | 3 |
;
S12:利用S11中获得的图数据集,得到图对应的度矩阵和邻接矩阵、特征矩阵;
S2:划分图数据集,具体过程为:
S21:将图数据集中部分数据进行掩码设置来进行半监督学习;
S22:将S12获得的矩阵中的数据划分为训练集、验证集和测试集;
S3:将S22获得的训练集输入浅层图卷积层获得浅层共享嵌入表示H,即将Cora=(A,X)输入浅层图卷积层,利用消息传播机制H=σ(AXW),其中A是邻接矩阵,X是节点特征矩阵,W是可学习的参数矩阵,σ是激活函数,以聚合当前邻居节点的特征信息和拓扑结构信息来更新当前节点的特征信息,来获得浅层共享嵌入表示H,图卷积网络的公式为:
其中,激活函数σ(·)=ReLU(·),W为权重矩阵,D是邻接矩阵的度矩阵;
S4:将S3获得的浅层共享嵌入表示H分别输入两个不同的下游网络框架,获得各自的嵌入表示,具体过程为:
S41:将S3获得的浅层共享嵌入表示H输入用于链路预测的图卷积网络中,获得嵌入表示Z_mean和Z_log;
S42:将Z_mean和Z_log利用高斯分布进行相加,获得符合高斯分布的嵌入表示Z;
S43:将Z作为判别器的假样本输入,基于生成对抗机制,使得嵌入表示Z能够更接近原始的样本分布;
S44:将S3获得的浅层表示H输入用于节点分类的图卷积网络中,获得嵌入表示Z_nc;
S5:将S4获得的两个不同的嵌入表示分别进行链路预测任务和半监督节点分类任务,具体过程为:
S51:将S4获得的嵌入表示Z输入内积层进行邻接矩阵重构,用于链路预测任务,得到重构的邻接矩阵:
重构邻接矩阵的损失函数为:
Lre=Eq(Z|X,A)[logp(A|Z)]-KL[q(Z|X,A)||p(Z)];
S52:将S4获得的嵌入表示Z输入图卷积层进行特征矩阵重构,作为链路预测的辅助任务,获得重构的特征矩阵:
重构特征矩阵的损失函数为:
S53:将S4获得嵌入表示Z_nc输入用于节点分类的图卷积网络中,点分类任务的损失函数为:
S54:计算损失函数,利用梯度下降算法来更新迭代参数,经过多次迭代之后使得损失函数可以收敛,其中,最终的损失函数公式为:
其中,C是一组节点标签,如果节点i属于c类,y是节点所属的类别标签,是节点i属于类c的softmax概率,当节点i在有标签时MASKi=1,否则MASKi=0,Eq(Z|X,A)[logp(A|Z)]-KL[q(Z|X,A)||p(Z)]为邻接矩阵的重构损失,其中KL[q(·)||p(·)]是生成样本与原始样本的相对熵,为半监督节点分类的交叉熵损失,为特征矩阵的重构损失。
在迭代50次之后,损失函数已经趋于收敛,停止训练。
三个图数据集的实验结果如表2、表3所示:
表2链路预测:AUC和AP评分
表3节点分类:准确率
Methods | Cora | Pubmed | Citeseer |
GCN | 0.815 | 0.790 | 0.703 |
Planetoid | 0.757 | 0.772 | 0.947 |
DeepWalk | 0.972 | 0.653 | 0.432 |
MTGAE | 0.790 | 0.804 | 0.718 |
本例 | 0.809 | 0.861 | 0.666 |
。
Claims (1)
1.一种基于多任务学习的变分图自动编码器方法,其特征在于,包括如下步骤:
S1:对源数据进行预处理,具体过程为:
S11:将引文网络中的源数据处理成图数据G=(V,E),V为节点集,E为边集,假设引文网络中的一篇论文视为图中的一个节点,论文的作者、研究方向视为节点的特征,论文与被引用论文之间建立一条连接的无向边,论文所属的类别视为标签,由此一个引文网络构成一个图数据集;
S12:利用S11中获得的图数据集,得到图对应的度矩阵和邻接矩阵、特征矩阵;
S2:划分图数据集,具体过程为:
S21:将图数据集中部分数据进行掩码设置来进行半监督学习;
S22:将S12获得的矩阵中的数据划分为训练集、验证集和测试集;
S3:将S22获得的训练集输入浅层图卷积层获得浅层共享嵌入表示H,即将邻接矩阵和特征矩阵输入浅层图卷积层,通过消息传播机制H=σ(AXW),其中A是邻接矩阵,X是节点特征矩阵,W是可学习的参数矩阵,σ是激活函数,以聚合当前邻居节点的特征信息和拓扑结构信息来更新当前节点的特征信息,从而获得浅层共享嵌入表示H;
S4:将S3获得的浅层共享嵌入表示H分别输入两个不同的下游网络框架,获得各自的嵌入表示,具体过程为:
S41:将S3获得的浅层共享嵌入表示H输入用于链路预测的图卷积网络中,获得嵌入表示Z_mean和Z_log;
S42:将Z_mean和Z_log利用高斯分布进行相加,获得符合高斯分布的嵌入表示Z;
S43:将Z作为判别器的假样本输入,基于生成对抗机制,使得嵌入表示Z能够更接近原始的样本分布;
S44:将S3获得的浅层表示H输入用于节点分类的图卷积网络中,获得嵌入表示Z_nc;S5:将S4获得的两个不同的嵌入表示分别进行链路预测任务和半监督节点分类任务,具体过程为:
S51:将S4获得的嵌入表示Z输入内积层进行邻接矩阵重构,用于链路预测任务;
S52:将S4获得的嵌入表示Z输入图卷积层进行特征矩阵重构,作为链路预测的辅助任务;
S53:将S4获得嵌入表示Z_nc输入用于节点分类的图卷积网络中;
S54:计算损失函数,利用梯度下降算法来更新迭代参数,经过多次迭代之后使得损失函数可以收敛,其中,最终的损失函数公式为:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111502928.6A CN114118416A (zh) | 2021-12-09 | 2021-12-09 | 一种基于多任务学习的变分图自动编码器方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111502928.6A CN114118416A (zh) | 2021-12-09 | 2021-12-09 | 一种基于多任务学习的变分图自动编码器方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114118416A true CN114118416A (zh) | 2022-03-01 |
Family
ID=80363953
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111502928.6A Pending CN114118416A (zh) | 2021-12-09 | 2021-12-09 | 一种基于多任务学习的变分图自动编码器方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114118416A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117112866A (zh) * | 2023-10-23 | 2023-11-24 | 人工智能与数字经济广东省实验室(广州) | 基于图表示学习的社交网络节点迁移可视化方法及系统 |
CN117633699A (zh) * | 2023-11-24 | 2024-03-01 | 成都理工大学 | 基于三元互信息图对比学习的网络节点分类算法 |
CN117972497A (zh) * | 2024-04-01 | 2024-05-03 | 中国传媒大学 | 基于多视图特征分解的虚假信息检测方法及系统 |
-
2021
- 2021-12-09 CN CN202111502928.6A patent/CN114118416A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117112866A (zh) * | 2023-10-23 | 2023-11-24 | 人工智能与数字经济广东省实验室(广州) | 基于图表示学习的社交网络节点迁移可视化方法及系统 |
CN117633699A (zh) * | 2023-11-24 | 2024-03-01 | 成都理工大学 | 基于三元互信息图对比学习的网络节点分类算法 |
CN117633699B (zh) * | 2023-11-24 | 2024-06-07 | 成都理工大学 | 基于三元互信息图对比学习的网络节点分类算法 |
CN117972497A (zh) * | 2024-04-01 | 2024-05-03 | 中国传媒大学 | 基于多视图特征分解的虚假信息检测方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114118416A (zh) | 一种基于多任务学习的变分图自动编码器方法 | |
CN112231562B (zh) | 一种网络谣言识别方法及系统 | |
CN112395466B (zh) | 一种基于图嵌入表示和循环神经网络的欺诈节点识别方法 | |
CN113065974B (zh) | 一种基于动态网络表示学习的链路预测方法 | |
CN115661550B (zh) | 基于生成对抗网络的图数据类别不平衡分类方法及装置 | |
CN113157957A (zh) | 一种基于图卷积神经网络的属性图文献聚类方法 | |
CN109711411B (zh) | 一种基于胶囊神经元的图像分割识别方法 | |
Qiu et al. | An adaptive social spammer detection model with semi-supervised broad learning | |
CN112685504A (zh) | 一种面向生产过程的分布式迁移图学习方法 | |
CN109446414A (zh) | 一种基于神经网络分类的软件信息站点快速标签推荐方法 | |
CN111861756A (zh) | 一种基于金融交易网络的团伙检测方法及其实现装置 | |
CN116152554A (zh) | 基于知识引导的小样本图像识别系统 | |
CN113822419A (zh) | 一种基于结构信息的自监督图表示学习运行方法 | |
CN110991603B (zh) | 一种神经网络的局部鲁棒性验证方法 | |
CN116307212A (zh) | 一种新型空气质量预测方法及系统 | |
Yuan et al. | Community detection with graph neural network using Markov stability | |
CN113836319B (zh) | 融合实体邻居的知识补全方法及系统 | |
CN110830291A (zh) | 一种基于元路径的异质信息网络的节点分类方法 | |
CN117272195A (zh) | 基于图卷积注意力网络的区块链异常节点检测方法及系统 | |
Yin et al. | Ncfm: Accurate handwritten digits recognition using convolutional neural networks | |
CN116628524A (zh) | 一种基于自适应图注意力编码器的社区发现方法 | |
Marconi et al. | Hyperbolic manifold regression | |
WO2022227957A1 (zh) | 一种基于图自编码器的融合子空间聚类方法及系统 | |
CN113191144B (zh) | 一种基于传播影响力的网络谣言识别系统及方法 | |
McDonald et al. | Hyperbolic embedding of attributed and directed networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |