CN112861936B - 一种基于图神经网络知识蒸馏的图节点分类方法及装置 - Google Patents

一种基于图神经网络知识蒸馏的图节点分类方法及装置 Download PDF

Info

Publication number
CN112861936B
CN112861936B CN202110102108.1A CN202110102108A CN112861936B CN 112861936 B CN112861936 B CN 112861936B CN 202110102108 A CN202110102108 A CN 202110102108A CN 112861936 B CN112861936 B CN 112861936B
Authority
CN
China
Prior art keywords
node
label
model
prediction result
soft
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110102108.1A
Other languages
English (en)
Other versions
CN112861936A (zh
Inventor
杨成
石川
刘佳玮
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Posts and Telecommunications
Original Assignee
Beijing University of Posts and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Posts and Telecommunications filed Critical Beijing University of Posts and Telecommunications
Priority to CN202110102108.1A priority Critical patent/CN112861936B/zh
Publication of CN112861936A publication Critical patent/CN112861936A/zh
Application granted granted Critical
Publication of CN112861936B publication Critical patent/CN112861936B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q50/00Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
    • G06Q50/10Services
    • G06Q50/20Education
    • G06Q50/205Education administration or guidance
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Business, Economics & Management (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Educational Administration (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Tourism & Hospitality (AREA)
  • Strategic Management (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Educational Technology (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Economics (AREA)
  • Human Resources & Organizations (AREA)
  • Marketing (AREA)
  • Primary Health Care (AREA)
  • General Business, Economics & Management (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明实施例提供了一种基于图神经网络知识蒸馏的图节点分类方法及装置,将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,通过使用教师模型的第一预测结果优化学生模型的,不需要教师模型和学生模型之间的集成或迭代,也能够提高分类精度,简化了优化学生模型的过程;学生模型采用标签传播公式,通过有标签的节点传播到相邻无标签的节点,可以从图结构的先验知识中受益;也通过特征变换公式,预测无标签点集的软标签,可以从特征的先验知识中受益,即数据集中具有硬标签的点集及无标签点集的特征。从而充分的使用先验知识,提高学生模型的分类效果,从而提高分类精度。

Description

一种基于图神经网络知识蒸馏的图节点分类方法及装置
技术领域
本发明涉及机器学习技术领域,特别是涉及一种基于图神经网络知识蒸馏的图节点分类方法及装置。
背景技术
知识蒸馏被提出用于模型压缩,其中训练一个小的轻量级学生模型来模仿预训练的大教师模型的软预测。教师模型中的知识经过蒸馏以后,会转化为学生模型中的知识。
目前,知识蒸馏与图卷积网络(Graph convolutional networks,简称GCN)结合在一起进行应用。比如,可靠数据蒸馏(Resilient Distributed Dataset,简称RDD)用相同的体系结构,训练了多个GCN学生,然后对这些GCN学生进行集成以获得更好的性能。这样两个接收大小不同的GCN相互学习。其中,教师模型和学生模型,都是GCN。
但是,目前使用GCN作为学生模型的分类效果较差。
发明内容
本发明实施例的目的在于提供一种基于图神经网络知识蒸馏的图节点分类方法及装置,用以解决现有技术中使用GCN作为学生模型的分类效果较差的技术问题。具体技术方案如下:
第一方面,本发明实施例提供了基于图神经网络知识蒸馏的图节点分类方法,包括:
获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,所述教师模型为图神经网络GNN分类器,并且通过所述图结构及所述数据集预训练得到的,所述教师模型用于预测无标签点集的软标签可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果;所述数据集包括:具有硬标签的点集及无标签点集,所述每个节点表示数据集中的一个点;所述点为基准引文或购买商品;
基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的;
将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;所述训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型。
进一步的,在所述基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果之前,所述方法还包括:
根据如下公式:
Figure BDA0002916368350000021
为每个节点初始化标签预测,以完成初始化学生模型;
其中,
Figure BDA0002916368350000022
为节点v在第k次迭代中的预测概率分布,
Figure BDA0002916368350000023
为节点v在初始化的预测概率分布,LP为标签传播的英文简称,∈为属于,
Figure BDA0002916368350000024
为实数集合,Y为标签集合,|.|为集合的基数,
Figure BDA0002916368350000025
为任意,VL为有标签节点集合,V为所有节点,L为有标签,VU为无标签节点集合,U为无标签,fLP为所述第一预测结果中的LP的最终预测。
进一步的,所述采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果,包括:
基于每个节点的特征,根据如下公式:
cu=zTXu,将每个节点的特征映射为所述每个节点的置信度;
其中,cu为所述每个节点的置信度,u为任一节点名称,z为可学习参数,
Figure BDA0002916368350000026
是一个可学习参数,∈为属于,
Figure BDA0002916368350000027
为d维实数集合,d为维度大小,T为转置,Xu为任一节点u的特征;
根据如下公式:
Figure BDA0002916368350000028
为每条边计算边权,其中,所述边为每两个节点之间的边;
其中,wuv为每两个节点u与节点v之间的边的边权,exp为指数函数运算符,cu为节点u的置信度,u′为从集合Nv中选取的任一节点,Nv为节点v和节点v的邻居组成的集合,∈为属于,∪为并集,{v}为节点v的集合,cu′为节点u′的置信度;
针对所有节点,根据如下特征变换公式:
fFT(v)=softmax(MLP(Xv)),将每个节点的特征变换为所述无标签点集的软标签的预测;
其中,fFT(v)为特征变换函数,softmax为归一化函数,FT为特征变换,MLP(.)为多层感知器;
根据如下公式:
Figure BDA0002916368350000031
更新每一层所述无标签点集的软标签的预测,得到所述无标签点集的软标签的第二预测结果;
其中,
Figure BDA0002916368350000032
作为一个整体,表示标签传播公式,
Figure BDA0002916368350000033
为对节点v执行第k+1层CPF函数,αv为平衡参数,u为节点名称,Nv为节点v的邻居,∈为属于,v为任一节点名称,∪为并集运算,{v}为节点v构成的集合,
Figure BDA0002916368350000034
为对节点u执行第k层CPF函数,fFT(v)为对节点v执行FT函数。
进一步的,所述将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型,包括:
将所述第一预测结果代入如下公式:
Figure BDA0002916368350000035
优化所述学生模型的无标签点集的软标签的预测,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;
其中,K为所述学生模型的总层数,min为最小化,Θ为参数集合,fGNN(v)为节点v的教师模型软标签,
Figure BDA0002916368350000036
为节点v的第k层CPF模型,CPF;Θ为模型及其参数,;为分隔符,‖.‖2为L2范数。
进一步的,采用如下步骤,所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果:
获取数据集,以及图结构中的节点集和边集;
通过所述图结构及所述数据集预训练,得到图神经网络教师模型;所述教师模型为图神经网络GNN分类器;
采用所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
第二方面,本发明实施例提供了基于图神经网络知识蒸馏的图节点分类装置,包括:
获取模块,用于获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,所述教师模型为图神经网络GNN分类器,并且通过所述图结构及所述数据集预训练得到的,所述教师模型用于预测无标签点集的软标签可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果;所述数据集包括:具有硬标签的点集及无标签点集,所述每个节点表示数据集中的一个点;所述点为基准引文或购买商品;
第一处理模块,用于基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的;
第二处理模块,用于将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;所述训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型。
进一步的,所述装置还包括:初始化模块,用于:
在所述基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果之前,根据如下公式:
Figure BDA0002916368350000041
为每个节点初始化标签预测,以完成初始化学生模型;
其中,
Figure BDA0002916368350000042
为节点v在第k次迭代中的预测概率分布,
Figure BDA0002916368350000043
为节点v在初始化的预测概率分布,LP为标签传播的英文简称,∈为属于,
Figure BDA0002916368350000044
为实数集合,Y为标签集合,|.|为集合的基数,
Figure BDA0002916368350000045
为任意,VL为有标签节点集合,V为所有节点,L为有标签,VU为无标签节点集合,U为无标签,fLP为所述第一预测结果中的LP的最终预测。
进一步的,所述第一处理模块,用于:
基于每个节点的特征,根据如下公式:
cu=zTXu,将每个节点的特征映射为所述每个节点的置信度;
其中,cu为所述每个节点的置信度,u为任一节点名称,z为可学习参数,
Figure BDA0002916368350000046
是一个可学习参数,∈为属于,
Figure BDA0002916368350000047
为d维实数集合,d为维度大小,T为转置,Xu为任一节点u的特征;
根据如下公式:
Figure BDA0002916368350000051
为每条边计算边权,其中,所述边为每两个节点之间的边;
其中,wuv为每两个节点u与节点v之间的边的边权,exp为指数函数运算符,cu为节点u的置信度,u′为从集合Nv中选取的任一节点,Nv为节点v和节点v的邻居组成的集合,∈为属于,∪为并集,{v}为节点v的集合,cu′为节点u′的置信度;
针对所有节点,根据如下特征变换公式:
fFT(v)=softmax(MLP(Xv)),将每个节点的特征变换为所述无标签点集的软标签的预测;
其中,fFT(v)为特征变换函数,softmax为归一化函数,FT为特征变换,MLP(.)为多层感知器;
根据如下公式:
Figure BDA0002916368350000052
更新每一层所述无标签点集的软标签的预测,得到所述无标签点集的软标签的第二预测结果;
其中,
Figure BDA0002916368350000053
作为一个整体,表示标签传播公式,
Figure BDA0002916368350000054
为对节点v执行第k+1层CPF函数,αv为平衡参数,u为节点名称,Nv为节点v的邻居,∈为属于,v为任一节点名称,∪为并集运算,{v}为节点v构成的集合,
Figure BDA0002916368350000055
为对节点u执行第k层CPF函数,fFT(v)为对节点v执行FT函数。
进一步的,所述第二处理模块,用于:
将所述第一预测结果代入如下公式:
Figure BDA0002916368350000056
优化所述学生模型的无标签点集的软标签的预测,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;
其中,K为所述学生模型的总层数,min为最小化,Θ为参数集合,fGNN(v)为节点v的教师模型软标签,
Figure BDA0002916368350000057
为节点v的第k层CPF模型,CPF;Θ为模型及其参数,;为分隔符,‖.‖2为L2范数。
进一步的,所述装置还包括:预测模块,用于:
获取数据集,以及图结构中的节点集和边集;
通过所述图结构及所述数据集预训练,得到图神经网络教师模型;所述教师模型为图神经网络GNN分类器;
采用所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
第三方面,本发明实施例提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现上述第一方面任一的方法的步骤。
第四方面,本发明实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述第一方面任一的方法。
本发明实施例有益效果:
本发明实施例提供的一种基于图神经网络知识蒸馏的图节点分类方法及装置,将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,通过使用教师模型的第一预测结果优化学生模型的,不需要教师模型和学生模型之间的集成或迭代,也能够提高分类精度,简化了优化学生模型的过程;学生模型采用标签传播公式,通过有标签的节点传播到相邻无标签的节点,可以从图结构的先验知识中受益;也通过特征变换公式,预测无标签点集的软标签,可以从特征的先验知识中受益,即数据集中具有硬标签的点集及无标签点集的特征。从而充分的使用先验知识,提高学生模型的分类效果,从而提高分类精度。
当然,实施本发明的任一产品或方法并不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的基于图神经网络知识蒸馏的图节点分类方法的构架示意图;
图2为本发明实施例提供基于图神经网络知识蒸馏的图节点分类方法的第一流程示意图;
图3为本发明实施例提供学生模型的示意图;
图4(a)为本发明实施例的Cora数据集上具有不同数量传播层的CPF-ind的分类精度;
图4(b)为本发明实施例的Cora数据集上具有不同数量传播层的CPF-tra的分类精度;
图5(a)为本发明实施例的GCN中每个类的标签节点数量从5个变化到50个的示意图;
图5(b)为本发明实施例的GAT中每个类的标签节点数量从5个变化到50个的示意图;
图5(c)为本发明实施例的APPNP中每个类的标签节点数量从5个变化到50个的示意图;
图5(d)为本发明实施例的SAGE中每个类的标签节点数量从5个变化到50个的示意图;
图5(e)为本发明实施例的SGC中每个类的标签节点数量从5个变化到50个的示意图;
图5(f)为本发明实施例的GCII中每个类的标签节点数量从5个变化到50个的示意图;
图5(g)为本发明实施例的GLP中每个类的标签节点数量从5个变化到50个的示意图;
图6(a)为本发明实施例的大GCN可解释性分析的平衡参数αv案例研究的第一示意图;
图6(b)为本发明实施例的大GAT可解释性分析的平衡参数αv案例研究的第二示意图;
图6(c)为本发明实施例的小GCN可解释性分析的平衡参数αv案例研究的第三示意图;
图6(d)为本发明实施例的小GAT可解释性分析的平衡参数αv案例研究的第四示意图;
图7(a)为本发明实施例的大GCN可解释性分析的置信度得分cv案例研究的第一示意图;
图7(b)为本发明实施例的大GAT可解释性分析的置信度得分cv案例研究的第二示意图;
图7(c)为本发明实施例的小GCN中可解释性分析的置信度得分cv案例研究的第三示意图;
图7(d)为本发明实施例的小GAT中可解释性分析的置信度得分cv案例研究的第四示意图;
图8为本发明实施例的基于图神经网络知识蒸馏的图节点分类装置的第一结构示意图;
图9为本发明实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
目前,知识蒸馏与图卷积网络(Graph convolutional networks,简称GCN)结合在一起进行应用。比如,可靠数据蒸馏(Resilient Distributed Dataset,简称RDD)用相同的体系结构,训练了多个GCN学生,然后对这些GCN学生进行集成以获得更好的性能。这样两个接收大小不同的GCN相互学习。其中,教师模型和学生模型,都是GCN。现有技术使用GCN作为学生模型,在融合结构和特征信息上存在缺陷,使得学生模型没有充分利用图结构和特征先验知识,使得学习的学生模型分类效果较差。
而,本发明实施例的教师模型的知识在于其软预测,学生模型能够进一步利用预训练的GNN中的知识,这样知识蒸馏可以将额外的先验知识引入学生模型,并且保留教师模型的知识。因此,学习的学生模型简化了优化学生模型的过程,具有更可解释的预测过程,并且可以利用GNN和基于图结构/特征的先验知识,提高学生模型的分类效果,从而提高分类精度。
故,为了解决上述RDD过程中的集成这些GCN学生增加了学习的学生模型的流程的问题;以及目前使用GCN作为学生模型的分类效果较差的问题,本发明实施例提供了一种基于图神经网络知识蒸馏的图节点分类方法及装置,将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,通过使用教师模型的第一预测结果优化学生模型的,不需要教师模型和学生模型之间的集成或迭代,也能够提高分类精度,简化了优化学生模型的过程;学生模型采用标签传播公式,通过有标签的节点传播到相邻无标签的节点,可以从图结构的先验知识中受益;也通过特征变换公式,预测无标签点集的软标签,可以从特征的先验知识中受益,即数据集中具有硬标签的点集及无标签点集的特征。从而充分的使用先验知识,提高学生模型的分类效果,从而提高分类精度。
下面首先对本发明实施例提供的一种基于图神经网络知识蒸馏的图节点分类方法进行介绍。
本发明实施例所提供的一种基于图神经网络知识蒸馏的图节点分类方法,应用于社交网络分析等应用场景。例如:将图神经网络应用于社交网络图,并将其节点分类预测使用本装置进行蒸馏。
如图1所示,学生模型的两种简单预测机制可确保充分利用基于图结构的先验知识和功能的先验知识。在知识蒸馏过程中,将提取图神经网络(Graph Neural Networks,简称GNN)教师模型中的知识(即教师模型的输出)并将其注入学生模型(即学生模型的输入)。因此,学生模型可以超越其相应的老师模型,得到更有效和可解释的预测。
如图2所示,本发明实施例所提供的基于图神经网络知识蒸馏的图节点分类方法,该方法可以包括如下步骤:
步骤110,获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,教师模型为GNN分类器,并且通过图结构及数据集预训练得到的,教师模型用于预测无标签点集的软标签可能为硬标签的第一概率,得到无标签点集的软标签的第一预测结果;数据集包括:具有硬标签的点集及无标签点集,每个节点表示数据集中的一个点;点为基准引文或购买商品。
需要说明的是:图结构上的半监督学习旨在基于给定网络结构和一个标签点集,为每个节点提供分类。GNN已经证明了图结构数据在分类节点标签方面的有效性。大多数GNN模型采用消息传递策略:每个节点从其邻域聚合特征,然后将具有非线性激活的分层映射函数应用于聚合信息。这样,GNN可以在其模型中利用图结构和节点特征信息。
本发明实施例的上述GNN分类器可以为任意学习的GNN分类器比如GCN或GAT,并将GNN分类器注入学生模型中。学生模型是通过两个预测机制构建的,即标签传播预测和特征变换预测。这样通过标签传播预测和特征变换预测分别保留了基于图结构的先验知识和基于特征的先验知识。这样将学生模型设置为标签传播预测和特征变换预测的可训练组合。这使学习的学生模型(即学生模型)可以从先验知识和GNN教师的知识中受益,以获得更有效的预测。此外,与GNN相比,学习的学生模型具有更可解释的预测过程。
为了能够通过教师模型能够优化学生模型,因此需要获得教师模型的输入,即数据集,以及图结构中的节点集和边集;需要获得教师模型的输出,即,预测所有无标签点集的软标签可能为所述硬标签的第一概率(也称为第一预测结果)。
当然,本发明实施例可以采用如下方式获得教师模型,比如直接不用训练,获得任一教师模型对无标签点集的软标签进行预测的第一预测结果。然后,使用与教师模型相同的数据集、相同的图结构中的节点集和边集,以及图神经网络模型的输出,优化后续的学生模型。再比如,获取数据集,以及图结构中的节点集和边集;将数据集,以及图结构中的节点集和边集作为待训练教师模型的输入,通过待训练教师模型进行预训练,得到上述教师模型,所述上述教师模型用于预测无标签点集的软标签可能为所述硬标签的第一概率。
本发明实施例中,采用如下步骤,所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果:获取数据集,以及图结构中的节点集和边集;通过所述图结构及所述数据集预训练,得到图神经网络教师模型;所述教师模型为图神经网络GNN分类器;采用所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
其中,数据集包括:基准引文或购买商品。比如所述基准引文可以包括如下一者。由机器学习论文构成的基准引文数据集、由计算机科学论文构成的基准引文数据集及引文数据集。购买商品可以包括如下一者:购买图书。参见如下表1的数据集统计信息。
表1数据集的统计数据
数据集 节点数 边数 特征数 类别数
Cora 2,485 5,069 1,433 7
Citeseer 2,110 3,668 3,703 6
Pubmed 19,717 44,324 500 3
A-Computers 13,381 245,778 767 10
A-Photo 7,487 119,043 745 8
本发明实施例可以使用五个公共基准数据集进行实验,数据集的统计数据如表1所示。本发明实施例仅考虑最大的连通分量,并将边视为无向边。有关数据集的详细信息如下:
Cora是一个由机器学习论文构成的基准引文数据集,其中每个节点表示一篇论文,节点的特征用稀疏词袋向量表示。边表示论文的引用关系,标签表示每篇论文的研究领域。
Citeseer是另一个由计算机科学论文构成的基准引文数据集,和Cora的设置相似。在本发明实施例使用的5个数据集中,Citeseer数据集中节点的特征最多。
Pubmed是一个引文数据集,由PubMed数据库中和糖尿病相关的论文构成。节点的特征是TF/IDF加权词频向量,标签表示论文中讨论的糖尿病类型。
A-Computers和A-Photo从Amazon共同购买图中提取得到,其中节点表示商品,边表示是否两个商品经常被共同购买,节点的特征表示由词袋模型编码的商品评论,标签是预定义的商品类别。本发明实施例在五个公开基准数据集上进行了实验,并使用了七个GNN模型(包括GCN,图形注意力网络(Graph Attention Network,简称GAT),图采样聚合网络(SAmple and aggreGatE,简称SAGE);基于个性化PageRank的图卷积神经网络的快速近似(fast approximation version of personalized propagation of neuralpredictions,简称APPNP);简化图卷积神经网络(Simple Graph Convolution,简称SGC);使用初始残差和恒等映射的图卷积网络(Graph Convolutional Network via Initialresidual and Identity mapping,简称GCNII);广义标签传播(generalized labelpropagation,简称GLP)作为教师模型。实验结果表明,学习的学生模型的平均表现优于教师模型1.4%-4.7%。性能提升始终存在,更重要的是,预测还具备更好的可解释性。
步骤120,基于第一预测结果,数据集及图结构,采用学生模型预测无标签点集的软标签,可能为硬标签的第二概率,得到无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的。
上述步骤120中,基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,可以称为标签传播LP预测,其是基于图的经典半监督学习模型。该模型仅遵循以下假设:由边连接(或占据相同流形)的节点极有可能共享相同的标签。基于此假设,标签将从标签的节点传播到未标签的节点以进行预测。通过边缘传播标签的参数化标签传播(Parameterized Label Propagation,简称PLP)模块强调了基于图结构的先验知识。因此,本发明实施例还引入了特征变换(Feature Transformation,简称FT)模块作为补充预测机制,即基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签。FT模块仅通过查看节点的特征,预测标签。学生模型是两个模型的组合,即PLP模型及FT模型。
通过训练轻量级学生模型来模仿预训练的教师模型的软预测,称为知识蒸馏。这样知识蒸馏被提出用于模型压缩,学生模型可以减少时间和空间的复杂性,同时又不损失预测的质量。
步骤130,将学生模型得到的第二预测结果,拟合教师模型得到的第一预测结果,直至学生模型得到的第二预测结果与教师模型得到的第一预测结果最相近,得到训练好的学生模型;训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型。
上述步骤130中,所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近的确定方式可以是通过最小化库尔巴克-莱布勒散度(Kullback-LeiblerDivergence,简称K-L散度),最大化交叉熵或者欧氏距离,确定学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近。可选的,本发明实施例使用欧几里得距离,欧几里得距离的效果最好,并且在数值上更稳定。
在本发明实施例中,将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,通过使用教师模型的第一预测结果优化学生模型的,不需要教师模型和学生模型之间的集成或迭代,也能够提高分类精度,简化了优化学生模型的过程;学生模型采用标签传播公式,通过有标签的节点传播到相邻无标签的节点,可以从图结构的先验知识中受益;也通过特征变换公式,预测无标签点集的软标签,可以从特征的先验知识中受益,即数据集中具有硬标签的点集及无标签点集的特征。从而充分的使用先验知识,提高学生模型的分类效果,从而提高分类精度。
如图3所示,本发明实施例的学生模型和教师模型构成的基于图神经网络知识蒸馏的图节点分类系统(即知识蒸馏系统)进行如下介绍。
图结构包括节点集和边集,连通图就是指图结构提取出的最大连通分量。本发明实施例中,连通图G=(V,E)和一个标签点集
Figure BDA0002916368350000123
其中,V是节点集,E是边集,节点分类的目标是为每个节点无标签点集VU=V\VL中的节点u预测标签。每个节点v∈V拥有标签yv∈Y,其中Y是所有可能的标签集合。此外,图数据通常拥有节点特征
Figure BDA0002916368350000121
并且可以利用节点的特征来提升分类准确率。每行矩阵X的每行
Figure BDA0002916368350000122
表示节点v的d维特征向量。
基于GNN的节点分类方法往往是一个黑盒,输入图结构G、标签点集VL和节点特征X,输出GNN分类器f。GNN分类器f将预测无标签点v∈VU的标签为y∈Y的概率f(v,y),其中∑y,∈Yf(v,y′)=1。对于标签节点v,如果v的标签为y,那么f(v,y)=1,其余标签f(v,y′)=0。简化起见,本发明实施例使用
Figure BDA0002916368350000131
表示所有标签的概率分布,其包括预测无标签点集的软标签可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
本发明实施例中的教师模型里的预训练分类器可以称为为fGNN。使用fSTU;Θ表示学生模型,Θ是参数,
Figure BDA0002916368350000132
表示学生模型对节点v的预测概率分布。
在知识蒸馏系统中,训练学生模型使其最小化与预训练教师模型的软标签预测,使得教师模型里的潜在知识被提取并注入学生模型中。因此,优化目标是对齐学生模型和与训练教师模型的输出,可以形式化为:
Figure BDA0002916368350000133
上述公式(1),distan(·,·)度量两个预测概率分布之间的距离,任意中心节点v,fGNN(v)为节点v的教师模型软标签,fSTU;Θ(v)为节点v的CPF模型,Θ为参数集合Θ为参数集合。
在知识蒸馏系统中,学生模型使得节点的标签预测遵循两种简单的机制:(1)从该节点的相邻节点传播标签;(2)从该节点自身特征进行转换。因此,如图3所示,本发明实施例将学生模型设计为两种机制的组合,即参数化标签传播(Parameterized LabelPropagation,简称PLP)模型和特征变换(Feature transformation,简称FT)模块,这样可以自然地分别保留基于图结构的先验知识的先验知识和基于特征的先验知识。蒸馏后,学生将通过更易于解释的预测机制从GNN和先验知识中受益。
在图3中,以任意中心节点v为例,学生模型从任意中心节点v的原始特征和统一的标签分布作为软标签开始,然后在每一层,将任意中心节点v的软标签预测更新为来自任意中心节点v的邻居的PLP和任意中心节点v的FT的可训练组合。最终,将使学生与经过训练的教师的软标签预测之间的距离最小化。通过本发明实施例的系统,以提取任意预训练的GNN模型的知识,并将其注入学生模型,以实现更有效和可解释的预测。
下面继续将首先介绍初始化。接着,继续将介绍PLP模块和FT模块,及两者可训练的组合。
在所述基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果之前,所述方法还包括如下过程:
本发明实施例使用fLP表示LP的最终预测,使用
Figure BDA0002916368350000141
表示k轮迭代后的LP预测。在这个工作中,如果v是标签节点,本发明实施例将对节点v的预测初始化为一个独热编码向量。否则,本发明实施例将为每个未标签的节点v设置均匀分布,这表明所有类的概率在开始时都是相同的。初始化可以形式化,即
根据如下公式:
Figure BDA0002916368350000142
公式(2),为每个节点初始化标签预测,以完成初始化学生模型;
其中,
Figure BDA0002916368350000143
为节点v在第k次迭代中的预测概率分布,
Figure BDA0002916368350000144
为节点v在初始化的预测概率分布,LP为标签传播的英文简称,∈为属于,
Figure BDA0002916368350000145
为实数集合,Y为标签集合,|.|为集合的基数,
Figure BDA0002916368350000146
为任意,VL为有标签节点集合,V为所有节点,L为有标签,VU为无标签节点集合,U为无标签,fLP为所述第一预测结果中的LP的最终预测。
在本发明实施例中,直接获得标签传播公式,在公式(8)中进行使用,或者说公式(8)直接获得标签传播公式的结果。其中上述公式(8)中的标签传播公式的推导过程如下:
在第k+1次迭代时,LP将按照如下方式更新无标签节点v∈VU的预测:
Figure BDA0002916368350000147
其中,Nv时节点v的邻居集合,λ是控制节点更新平滑度的超参。
注意LP没有需要训练的参数,因此以端到端的方式不能拟合教师模型的输出。因此,本发明实施例通过引入更多参数来提升LP的表达能力。
本发明实施例将通过在LP中进一步参数化边缘权重来介绍本发明实施例的参数化标签传播(PLP)模块。如公式(3)所示,LP模型在传播过程中平等对待节点的所有邻居。但是,本发明实施例假设不同邻居对一个节点的重要性应该不同,这决定了节点之间的传播强度。更具体地说,本发明实施例假设某些节点的标签预测比其他节点更“自信”。例如,一个节点的预测标签与其大多数邻居相似。这样的节点将更有可能将其标签传播给邻居,并使它们保持不变。
形式化来说,本发明实施例将给每个节点v设置一个置信度
Figure BDA0002916368350000148
在传播过程中,所有节点v的邻居和v自身将把他们的标签传播给v。基于置信值越大,边缘权值越大的直觉,本发明实施例为fPLP重写了公式(3)中的预测更新函数如下:
Figure BDA0002916368350000151
上述公式(4)重写了公式(3),后续可以通过公式(8)重写公式(4)。
基于上述公式(4)的推导过程,在一种可能的实现方式中,所述采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果,包括:
作为可选项,本发明实施例可以进一步参数化置信度cv用于归纳设置,即基于每个节点的特征,根据如下公式:
cu=zTXu公式(6),将每个节点的特征映射为所述每个节点的置信度;
其中,cu为所述每个节点的置信度,u为任一节点名称,z为可学习参数,
Figure BDA0002916368350000152
是一个可学习参数,∈为属于,
Figure BDA0002916368350000153
为d维实数集合,d为维度大小,T为转置,Xu为任一节点u的特征。如果不需要归纳设置,则可以不用执行公式(6)。
根据如下公式:
Figure BDA0002916368350000154
公式(5)为每条边计算边权,其中,所述边为每两个节点之间的边;
其中,与LP相似,
Figure BDA0002916368350000155
按照公式(2)初始化,在传播过程中,每个标签点v∈VL
Figure BDA0002916368350000156
仍然保持独热真实编码向量。wuv为每两个节点u与节点v之间的边的边权,exp为指数函数运算符,cu为节点u的置信度,u′为从集合Nv中选取的任一节点,Nv为节点v和节点v的邻居组成的集合,∈为属于,∪为并集,{v}为节点v的集合,cu′为节点u′的置信度。
由于FT模块仅通过查看节点的原始特征来预测标签。形式化来说,用fFT表示FT模块的预测,使用两层MLP后接一个softmax函数来将特征变换为软标签预测,即针对所有节点,根据如下特征变换公式:
fFT(v)=softmax(MLP(Xv))公式(7),将每个节点的特征变换为所述无标签点集的软标签的预测;
注:虽然单层逻辑回归更具可解释性,但两层逻辑回归对于提高学生的模型能力是必要的。其中,fFT(v)为特征变换函数,softmax为归一化函数,FT为特征变换,MLP(.)为多层感知器。
为了能够可训练组合特征变换和传播标签,本发明实施例将结合PLP和FT模块作为学习的学生模型。本发明实施例将为每个节点v学习一个可训练参数αv∈[0,1],来平衡PLP和FT之间的预测。换句话说,FT和PLP的预测将在每个传播步骤合并。本发明实施例将合并后的完整模型命名为CPF,公式(4)中的每个无标签节点v∈VU的预测更新公式可以重新写做,即根据如下公式:
Figure BDA0002916368350000161
公式(8)更新每一层所述无标签点集的软标签的预测,得到所述无标签点集的软标签的第二预测结果;
边权wuv和初始化
Figure BDA0002916368350000162
与PLP模块一致。根据是否按照公式(6)参数化置信度Cv,模型有两个变体,分别是归纳模型CPF-ind和转导模型CPF-tra。其中,
Figure BDA0002916368350000163
作为一个整体,表示标签传播公式,
Figure BDA0002916368350000164
为对节点v执行第k+1层CPF函数,αv为平衡参数,u为节点名称,Nv为节点v的邻居,∈为属于,v为任一节点名称,∪为并集运算,{v}为节点v构成的集合,
Figure BDA0002916368350000165
为对节点u执行第k层CPF函数,fFT(v)为对节点v执行FT函数。
在本发明实施例中,对标签传播的概率分布和通过特征变换公式,预测无标签点集的软标签之间的加权平均值进行预测,平衡参数指示基于图结构的PLP还是基于特征的MLP对于节点的预测更重要。这样可以轻松地找出节点在每个迭代中受哪个邻居影响的程度。这样使得模型具备一定的可解释性,有助于科研人员改进模型。另一方面,对基于特征的MLP的理解可以通过现有技术或直接查看不同特征的梯度来获得。因此,学习过的学生模型比GNN教师具有更好的解释性。
在一种可能的实现方式中,所述将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型,包括:
将所述第一预测结果代入如下公式:
Figure BDA0002916368350000166
公式(9)优化所述学生模型的无标签点集的软标签的预测,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;
假设本发明实施例的学生模型一共有K层,上述公式(1)中的蒸馏目标重写为上述公式(9)。其中,K为所述学生模型的总层数,min为最小化,Θ为参数集合,fGNN(v)为节点v的教师模型软标签,
Figure BDA0002916368350000167
为节点v的第k层CPF模型,CPF;Θ为模型及其参数,;为无含义,仅为区分前后参数的分隔符,‖.‖2为L2范数。PLP模块内部的置信度参数
Figure BDA0002916368350000168
(或归纳设置下的参数z),以及FT模块中MLP的参数ΘMLP。还有一个重要的超参数:传播层数K。
通过算法1展示了训练过程的伪代码。
输入:图G=(V,E),有标签节点集合
Figure BDA0002916368350000171
无标签节点集合
Figure BDA0002916368350000172
节点特征X,预训练GNN分类器fGNN(也可以称为教师模型)。
输出:学习的学生分类器fCPF(也可以称为学生模型)。
第1步骤,根据公式(2),为每个节点v初始化标签预测
Figure BDA0002916368350000173
第2步骤,while没有收敛do;
第3步骤,if归纳设置then;
第4步骤,根据公式(6),为每个节点v∈V计算置信度Cv
第5步骤,end if;
第6步骤,根据公式(5),为每条边(u,v)∈E计算边权;
第7步骤,for所有节点v∈VUdo;
第8步骤,根据公式(7),计算FT模块的预测fFT(v);
第9步骤,for k=1,2,…,K do;
第10步骤,根据公式(8),更新第k层的预测
Figure BDA0002916368350000174
第11步骤,end for;
第12步骤,end for;
第13步骤,根据公式9的优化目标更新参数;
第14步骤,end while。
本发明实施例中的u,v,a,b,c均表示任一节点。
在训练前,需要根据公式(2),为每个节点v初始化标签预测
Figure BDA0002916368350000175
然后,执行如下对学生模型的训练过程直至模型收敛:首先,如果是归纳设置,本发明实施例需要根据公式(6),为每个节点v∈V计算置信度分数cv。接下来,根据公式(5),为每条边(u,v)∈E计算边权。然后,对于所有节点v∈VU,根据公式(7),计算FT模块的预测fFT(v),并根据公式(8),更新每一层的预测fCPF(v),最后,根据公式(9)的优化目标更新参数,直至模型收敛。训练结束后,将测试数据输入学生分类器fCPF,得到节点分类结果。
上述算法1每次迭代算法1的第3行到第13行的时间复杂度和空间复杂度都是O(|E|+d|V|),这和数据集的规模线性相关。事实上,操作可以简单写成矩阵形式,对于真实数据集的训练过程,使用单GPU可以在几秒内完成。因此,本发明实施例提出的知识蒸馏系统的时间、空间效率都很高。模型复杂度低,所以时间、空间效率高。
为了说明本发明实施例的实现效果,进行如下描述:
在本节中,将从介绍实验中使用的数据集和教师模型开始。然后,本发明实施例将详细介绍教师模型和学生变体的实验设置。之后,本发明实施例将给出评估半监督节点分类的定量结果。本发明实施例还在不同数量的传播层和训练比率下进行实验,以说明算法的鲁棒性。最后,将提供定性案例研究和可视化效果,以更好地理解本发明实施例的学生模型CPF中的学习参数。
基于上述步骤110中的数据集,从每个类别中随机抽取20个节点作为标签节点,30个用于验证节点,所有其他节点用于测试。
(1)教师模型及其设置。
为了进行全面比较,在本发明实施例的知识蒸馏系统中考虑了七个GNN模型作为教师模型:
GCN是一个经典的半监督模型,通过在图结构数据上定义卷积核学习节点表示。GCN对层数敏感,本发明实施例使用2层的常用设置。
GAT通过引入注意力机制为每个节点的不同邻居分配不同的权重,从而改进GCN。本发明实施例使用2层GAT和8个注意力头作为教师模型。
APPNP通过平衡局部信息和更大范围内的邻居信息来提升GCN。本发明实施例使用2层和10次迭代的APPNP。
SAGE通过采样和聚合节点的局部邻居信息来学习节点嵌入。本发明实施例使用SAGE-GCN的变体作为教师模型。
SGC通过剔除GCN层之间的非线性变换以及权重矩阵的压缩来降低GCN的算法复杂度。与GCN类似,本发明实施例也使用2层设置。
GCNII是一个深层模型,通过使用初始残差和恒等映射来避免GCN模型的过平滑问题。这里本发明实施例使用16层的GCNII作为教师模型。
GLP是一个标签利用高效的模型,使用图滤波系统结合了标签传播和图卷积操作。GLP有两个模型变体:GLP-RNM和GLP-AR,对于每个数据集,本发明实施例都使用效果更好的变体作为教师模型。
(2)、学生模型的变体和实验设置。
对于每个数据集和教师模型,本发明实施例测试下列学生变体:
PLP:只考虑参数化标签传播机制的学生变体;
FT:只考虑特征变换机制的学生变体;
CPF-ind:归纳设置下的完整模型;
CPF-tra:转导设置下的完整模型。
本发明实施例随机初始化参数,使用忍耐度为50的早停法,也就是说,如果验证集上的分率正确率在50个epoch内不在上升就停止训练。对于超参优化,本发明实施例使用启发式搜索策略,搜索空间包括层数K∈{5,6,7,8,9,10},隐层维度dMLP∈{8,16,32,64},随机失活概率dr∈{0.2,0.5,0.8},学习率lr∈{0.001,0.005,0.01}和Adam优化器的权重衰减概率wd∈{0.0005,0.001,0.01}。本发明实施例承诺在论文接受后提供代码和数据划分以供复现。
(3)、分类结果分析。
表2 GCN和GAT作为教师模型的分类准确率
Figure BDA0002916368350000191
表3 APPNP和SGAE作为教师模型的分类准确率
Figure BDA0002916368350000192
表4 SGC和GCNII作为教师模型的分类准确率
Figure BDA0002916368350000193
表5 GLP作为教师模型的分类准确率
Figure BDA0002916368350000194
五个数据集、七个GNN教师模型、四个学生变体模型上的实验结果在表格2,3,4,5中展示。(注:本发明实施例在A-Computer/A-Photo上省略了GLP的结果,因为在本发明实施例的实验中,GLP在这两个数据集上的表现比其他GNN模型差得多。)
基于以上分析,得出如下结论:
借助学生模型CPF-ind和CPF-tra的完整体系结构,能够一致且显着地改善相应教师模型的性能。例如,将Cora数据集上GCN的分类精度从0.8244提高到0.8576。这是因为可以提取GNN教师的知识并将其注入本发明实施例的学生模型,这也得益于其简单的预测机制引入的基于图结构/功能的先验知识。此观察证明了本发明实施例的动机和系统的有效性。
教师模型广义标签传播GLP已经在其图过滤器中加入了标签传播机制。如表5所示,通过应用本发明实施例的知识蒸馏系统,本发明实施例仍然可以获得1.5%-2.3%的相对改进,这说明了本发明实施例的算法的潜在兼容性。
在这四个学生变体中,完整模型CPF-ind和CPF-tra始终表现最佳(Pubmed数据集上的APPNP老师除外),并给出了具有竞争力的结果。因此,基于图结构的PLP和基于功能的FT模块都将为整体改进做出贡献。PLP本身表现最差,因为要学习的参数很少的PLP的模型容量较小,无法满足教师模型的软预测。
七个教师模型GCN/GAT/APPNP/SAGE/SGC/GCNII/GLP的平均相对改进分别为3.9/3.2/1.4/4.7/4.1/1.9/2.0%。APPNP的改进是最小的。一个可能的原因是APPNP在消息传递时保留了节点的特征,因此同样利用了基于特征的先验知识,正如FT模块所做。
五个数据集Cora/Citeseer/Pubmed/A-Computers/A-Photo的平均相对改进分别为2.9/4.2/2.7/2.1/2.7%。Citeseer数据集提升最明显。可能的原因是Citeseer具有最多的特征,因此学生模型还具有更多可训练的参数以提升能力。这样五个基准数据集和七个GNN教师模型上的实验结果表明了本发明实施例的系统有效性。对学生模型中学习权重的广泛研究也说明了本发明实施例方法的可解释性。
在不同传播层数的分析中,将研究关键超参数对学生模型CPF的体系结构(即传播层数)的影响。实际上,流行的GNN模型(例如GCN和GAT)对层数非常敏感。较大数量的层将导致过平滑的问题,并严重损害模型性能。因此,本发明实施例在Cora数据集上进行了实验,以进一步分析该超参数。
图4(a)和图4(b)给出了传播层数K∈{5,6,7,8,9,10}下学生模型CPF-ind和CPF-tra的分类结果,可以看出不同K之间的差距相对较小:对于每位教师,本发明实施例计算其相应学生的最佳和最差成绩之间的差距,对于CPF-ind和CPF-tra而言,最大差距分别为0.56%和0.84%。此外,在最差的K∈{5,6,7,8,9,10}下,CPF的准确性已经超过了相应的教师。因此,当传播层的数量在合理范围内变化时,本发明实施例系统的性能提升非常鲁棒。
在不同训练比例的分析中,为了进一步证明该系统的有效性,本发明实施例在不同的训练比例下进行了额外的实验。具体来说,本发明实施例以Cora数据集为例,将每个类的标签节点数量从5个变化到50个。实验结果如图5(a)、图5(b)、图5(c)、图5(d)、图5(e)、图5(f)及图5(g)所示。请注意,本发明实施例省略了PLP的结果,因为它的性能很差,无法与数据吻合。
本发明实施例可以看到,学到的CPF-ind和CPF-tra的学生在每类使用不同数量的标签节点的情况下,在性能上始终优于预训练的GNN教师模型,这说明了本发明实施例系统的稳定性。与之相比,FT模块具有足够的模型能力来适应教师的预测,但没有进一步的改进。因此,作为补充的预测机制,PLP模块在本发明实施例的系统中也非常重要。
另一个观察结果是,对于少样本设置(即每类仅标签5个节点),学生相对于相应教师模型的改进更为显着。显然,每类5/10/20/50个标签节点的分类准确度平均提高了4.9/4.5/3.2/2.1%。因此,本发明实施例的算法还具有处理少样本设置的能力,这是半监督学习中的重要研究问题。
在可解释性分析中,将分析学习的学生模型CPF的可解释性。具体来说,本发明实施例将探究PLP和FT之间的学习平衡参数αv以及每个节点的置信度得分cv。本发明实施例的目标是找出哪种节点具有最大或最小的αv和cv。在本小节中,本发明实施例将使用由GCN和GAT教师模型指导的CPF-ind学生模型在Cora数据集上进行展示。
平衡参数αv用于指示基于图结构的LP还是基于特征的MLP为节点v的预测做出更多贡献。如图6(a)、图6(b)、图6(c)及图6(d)所示,此处图6(a)、图6(b)、图6(c)及图6(d)的子标题表示该节点是按GCN/GAT作为教师模型,按大或小值选择的。本发明实施例分析了具有最大/最小的前十个节点,并选择了四个代表性节点进行案例研究。本发明实施例绘制每个节点的一阶邻域,并使用不同的形状图案指示不同的预测标签。本发明实施例发现,具有更大αv的节点更有可能具有相同的预测邻居。相反,一个αv较小的节点可能会有更多具有不同预测标签的邻居。该观察结果符合本发明实施例的直觉,即节点的预测如果具有许多带有各种预测标签的邻居,将被混淆,因此不能从标签传播中受益良多。
置信度得分cv在本发明实施例的学生体系结构中具有较高置信度得分的节点将具有较大的边缘权重,以将其标签传播到邻居并保持自身不变。类似地,如图7(a)、图7(b)、图7(c)及图7(d)所示,本发明实施例还研究了置信度得分最大/最小的前10个节点,并选择四个有代表性的节点进行案例研究。本发明实施例可以看到,具有高置信度的节点也将具有相对较小的度,并且具有相同的预测邻居。相反,具有低置信度的节点将比具有较小置信度的节点具有更多样化的邻域。直观地,节点的多样化邻域将导致较低的置信度来传播其标签。这一发现证实了本发明实施例建模节点置信度的动机。
下面继续对本发明实施例提供的一种基于图神经网络知识蒸馏的图节点分类装置进行介绍。
参见图8,图8为本发明实施例提供的一种基于图神经网络知识蒸馏的图节点分类装置的结构示意图。本发明实施例所提供的一种基于图神经网络知识蒸馏的图节点分类装置,可以包括如下模块:
获取模块21,用于获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,所述教师模型为图神经网络GNN分类器,并且通过所述图结构及所述数据集预训练得到的,所述教师模型用于预测无标签点集的软标签可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果;所述数据集包括:具有硬标签的点集及无标签点集,所述每个节点表示数据集中的一个点;所述点为基准引文或购买商品;
第一处理模块22,用于基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的;
第二处理模块23,用于将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;所述训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型。
在一种可能的实现方式中,所述装置还包括:初始化模块,用于:
在所述基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果之前,根据如下公式:
Figure BDA0002916368350000221
为每个节点初始化标签预测,以完成初始化学生模型。
在一种可能的实现方式中,所述第一处理模块,用于:
基于每个节点的特征,根据如下公式:
cu=zTXu,将每个节点的特征映射为所述每个节点的置信度;
根据如下公式:
Figure BDA0002916368350000222
为每条边计算边权,其中,所述边为每两个节点之间的边;
针对所有节点,根据如下特征变换公式:
fFT(v)=softmax(MLP(Xv)),将每个节点的特征变换为所述无标签点集的软标签的预测;
根据如下公式:
Figure BDA0002916368350000223
更新每一层所述无标签点集的软标签的预测,得到所述无标签点集的软标签的第二预测结果。
在一种可能的实现方式中,所述第二处理模块,用于:
将所述第一预测结果代入如下公式:
Figure BDA0002916368350000224
优化所述学生模型的无标签点集的软标签的预测,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型。
在一种可能的实现方式中,所述装置还包括:预测模块,用于:
获取数据集,以及图结构中的节点集和边集;
通过所述图结构及所述数据集预训练,得到图神经网络教师模型;所述教师模型为图神经网络GNN分类器;
采用所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
在本发明实施例中,将学生模型设计为参数化标签传播和基于特征的两层MLP的可训练组合。因此,学生模型有一个更可解释的预测过程,并自然地保留了基于图结构/特征的先验。因此,学习的学生模型可以同时利用GNN和先验知识。
本发明实施例提供了基于图神经网络知识蒸馏的图节点分类系统,即一种知识蒸馏系统,可以提取任意预训练的GNN(教师模型)的知识并将其注入精心设计的学生模型中。学生模型CPF被建立为两个简单预测机制的可训练组合:标签传播和特征变换,二者分别强调基于结构的先验知识和基于特征的先验知识。蒸馏后,学习的学生可以利用先验知识和GNN知识,从而超越GNN老师。在五个基准数据集上的实验结果表明,本发明实施例的系统可以通过更可解释的预测过程来一致,显着地改善所有七个GNN教师模型的分类精度。在不同数量的训练比率和传播层数上进行的附加实验证明了本发明实施例算法的鲁棒性。本发明实施例还提供了案例研究,以了解学生架构中学习到的平衡参数和置信度得分。
在未来的工作中,除了半监督节点分类之外,本发明实施例还将探索将本发明实施例的系统用于其他基于图的应用。例如,无监督节点聚类任务会很有趣,因为标签传播模式在没有标签的情况下不能应用。另一个方向是改进本发明实施例的系统,鼓励教师和学生模型互相学习,以取得更好的成绩。
下面继续对本发明实施例提供的电子设备进行介绍。
参见图9,图9为本发明实施例提供的电子设备的结构示意图。本发明实施例还提供了一种电子设备,包括处理器31、通信接口32、存储器33和通信总线34,其中,处理器31,通信接口32,存储器33通过通信总线34完成相互间的通信,
存储器33,用于存放计算机程序;
处理器31,用于执行存储器33上所存放的程序时,实现上述基于图神经网络知识蒸馏的图节点分类方法的步骤,在本发明一个可能的实现方式中,可以实现如下步骤:
获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,所述教师模型为图神经网络GNN分类器,并且通过所述图结构及所述数据集预训练得到的,所述教师模型用于预测无标签点集的软标签可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果;所述数据集包括:具有硬标签的点集及无标签点集,所述每个节点表示数据集中的一个点;所述点为基准引文或购买商品;
基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的;
将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;所述训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型。
上述电子设备提到的通信总线可以是PCI(Peripheral ComponentInterconnect,外设部件互连标准)总线或EISA(Extended Industry StandardArchitecture,扩展工业标准结构)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括RAM(Random Access Memory,随机存取存储器),也可以包括NVM(Non-Volatile Memory,非易失性存储器),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括CPU(Central Processing Unit,中央处理器)、NP(Network Processor,网络处理器)等;还可以是DSP(Digital Signal Processing,数字信号处理器)、ASIC(Application Specific Integrated Circuit,专用集成电路)、FPGA(Field-Programmable Gate Array,现场可编程门阵列)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
本发明实施例提供的方法可以应用于电子设备。具体的,该电子设备可以为:台式计算机、便携式计算机、智能移动终端、服务器等。在此不作限定,任何可以实现本发明实施例的电子设备,均属于本发明的保护范围。
本发明实施例提供了一种计算机可读存储介质,所述存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现上述的基于图神经网络知识蒸馏的图节点分类方法的步骤。
本发明实施例提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述的基于图神经网络知识蒸馏的图节点分类方法的步骤。
本发明实施例提供了一种计算机程序,当其在计算机上运行时,使得计算机执行上述的基于图神经网络知识蒸馏的图节点分类方法的步骤。
需要说明的是,在本发明实施例中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置/系统/电子设备/存储介质/包含指令的计算机程序产品/计算机程序实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本发明的较佳实施例,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本发明的保护范围内。

Claims (8)

1.一种基于图神经网络知识蒸馏的图节点分类方法,其特征在于,包括:
获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,所述数据集包括:基准引文或购买商品;所述教师模型为图神经网络GNN分类器,并且通过所述图结构及所述数据集预训练得到的,所述教师模型用于预测无标签点集的软标签可能为硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果;所述数据集包括:具有硬标签的点集及无标签点集,所述每个节点表示数据集中的一个点;所述点为基准引文或购买商品;其中,所述点为基准引文时,所述基准引文包括:机器学习论文、计算机科学论文,所述每个节点表示一篇论文,所述每个边表示论文的引用关系,标签表示论文的研究领域;所述点为购买商品时,所述每个节点表示商品,所述每个边表示是否两个商品经常被共同购买,标签表示商品类别;
基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的;
将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;所述训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型;
所述采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果,包括:
基于每个节点的特征,根据如下公式:
cuTXu,将每个节点的特征映射为所述每个节点的置信度;
其中,cu为所述每个节点的置信度,u为任一节点名称,z为可学习参数,
Figure QLYQS_1
是一个可学习参数,∈为属于,
Figure QLYQS_2
为d维实数集合,d为维度大小,T为转置,Xu为任一节点u的特征;
根据如下公式:
Figure QLYQS_3
为每条边计算边权,其中,所述边为每两个节点之间的边;
其中,wuv为每两个节点u与节点v之间的边的边权,exp为指数函数运算符,cu为节点u的置信度,u′为从集合Nv中选取的任一节点,Nv为节点v和节点v的邻居组成的集合,∈为属于,∪为并集,{v}为节点v的集合,cu′为节点u′的置信度;
针对所有节点,根据如下特征变换公式:
fFT(v)=softmax(MLP(Xv)),将每个节点的特征变换为所述无标签点集的软标签的预测;
其中,fFT(v)为特征变换函数,softmax为归一化函数,FT为特征变换,MLP(.)为多层感知器;
根据如下公式:
Figure QLYQS_4
更新每一层所述无标签点集的软标签的预测,得到所述无标签点集的软标签的第二预测结果;
其中,
Figure QLYQS_5
作为一个整体,表示标签传播公式,
Figure QLYQS_6
为对节点v执行第k+1层CPF函数,αv为平衡参数,u为节点名称,Nv为节点v的邻居,∈为属于,v为任一节点名称,∪为并集运算,{v}为节点v构成的集合,
Figure QLYQS_7
为对节点u执行第k层CPF函数,fFT(v)为对节点v执行FT函数。
2.如权利要求1所述的方法,其特征在于,在所述基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果之前,所述方法还包括:
根据如下公式:
Figure QLYQS_8
为每个节点初始化标签预测,以完成初始化学生模型;
其中,
Figure QLYQS_9
为节点v在第k次迭代中的预测概率分布,
Figure QLYQS_10
为节点v在初始化的预测概率分布,LP为标签传播的英文简称,∈为属于,
Figure QLYQS_11
为实数集合,Y为标签集合,|.|为集合的基数,
Figure QLYQS_12
为任意,VL为有标签节点集合,V为所有节点,L为有标签,VU为无标签节点集合,U为无标签,fLP为所述第一预测结果中的LP的最终预测。
3.如权利要求1所述的方法,其特征在于,所述将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型,包括:
将所述第一预测结果代入如下公式:
Figure QLYQS_13
优化所述学生模型的无标签点集的软标签的预测,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;
其中,K为所述学生模型的总层数,min为最小化,Θ为参数集合,fGNN(v)为节点v的教师模型软标签,
Figure QLYQS_14
为节点v的第k层CPF模型,CPF;Θ为模型及其参数,;为分隔符,‖.‖2为L2范数。
4.如权利要求1至3任一项所述的方法,其特征在于,采用如下步骤,所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果:
获取数据集,以及图结构中的节点集和边集;
通过所述图结构及所述数据集预训练,得到图神经网络教师模型;所述教师模型为图神经网络GNN分类器;
采用所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
5.一种基于图神经网络知识蒸馏的图节点分类装置,其特征在于,包括:
获取模块,用于获取图神经网络教师模型,数据集,以及图结构中的节点集和边集,其中,所述数据集包括:基准引文或购买商品;所述教师模型为图神经网络GNN分类器,并且通过所述图结构及所述数据集预训练得到的,所述教师模型用于预测无标签点集的软标签可能为硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果;所述数据集包括:具有硬标签的点集及无标签点集,所述每个节点表示数据集中的一个点;所述点为基准引文或购买商品;其中,所述点为基准引文时,所述基准引文包括:机器学习论文、计算机科学论文,所述每个节点表示一篇论文,所述每个边表示论文的引用关系,标签表示论文的研究领域;所述点为购买商品时,所述每个节点表示商品,所述每个边表示是否两个商品经常被共同购买,标签表示商品类别;
第一处理模块,用于基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果;所述学生模型是基于所有硬标签的分布,采用标签传播公式,预测有标签的节点传播到相邻无标签的节点的概率分布,以及基于每个节点的特征,采用特征变换公式,预测无标签点集的软标签进行组合训练的;
第二处理模块,用于将所述学生模型得到的第二预测结果,拟合所述教师模型得到的第一预测结果,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;所述训练好的学生模型用于对无标签点集的软标签的预测,得到基准引文的类型或购买商品的类型;
所述第一处理模块,用于:
基于每个节点的特征,根据如下公式:
cuTXu,将每个节点的特征映射为所述每个节点的置信度;
其中,cu为所述每个节点的置信度,u为任一节点名称,z为可学习参数,
Figure QLYQS_15
是一个可学习参数,∈为属于,
Figure QLYQS_16
为d维实数集合,d为维度大小,T为转置,Xu为任一节点u的特征;
根据如下公式:
Figure QLYQS_17
为每条边计算边权,其中,所述边为每两个节点之间的边;
其中,wuv为每两个节点u与节点v之间的边的边权,exp为指数函数运算符,cu为节点u的置信度,u′为从集合Nv中选取的任一节点,Nv为节点v和节点v的邻居组成的集合,∈为属于,∪为并集,{v}为节点v的集合,cu′为节点u′的置信度;
针对所有节点,根据如下特征变换公式:
fFT(v)=oftmax(MLP(Xv)),将每个节点的特征变换为所述无标签点集的软标签的预测;
其中,fFT(v)为特征变换函数,softmax为归一化函数,FT为特征变换,MLP(.)为多层感知器;
根据如下公式:
Figure QLYQS_18
更新每一层所述无标签点集的软标签的预测,得到所述无标签点集的软标签的第二预测结果;
其中,
Figure QLYQS_19
作为一个整体,表示标签传播公式,
Figure QLYQS_20
为对节点v执行第k+1层CPF函数,αv为平衡参数,u为节点名称,Nv为节点v的邻居,∈为属于,v为任一节点名称,∪为并集运算,{v}为节点v构成的集合,
Figure QLYQS_21
为对节点u执行第k层CPF函数,fFT(v)为对节点v执行FT函数。
6.如权利要求5所述的装置,其特征在于,所述装置还包括:初始化模块,用于:
在所述基于所述第一预测结果,所述数据集及所述图结构,采用学生模型预测无标签点集的软标签,可能为所述硬标签的第二概率,得到所述无标签点集的软标签的第二预测结果之前,根据如下公式:
Figure QLYQS_22
为每个节点初始化标签预测,以完成初始化学生模型;
其中,
Figure QLYQS_23
为节点v在第k次迭代中的预测概率分布,
Figure QLYQS_24
为节点v在初始化的预测概率分布,LP为标签传播的英文简称,∈为属于,
Figure QLYQS_25
为实数集合,Y为标签集合,|.|为集合的基数,
Figure QLYQS_26
为任意,VL为有标签节点集合,V为所有节点,L为有标签,VU为无标签节点集合,U为无标签,fLP为所述第一预测结果中的LP的最终预测。
7.如权利要求5所述的装置,其特征在于,所述第二处理模块,用于:
将所述第一预测结果代入如下公式:
Figure QLYQS_27
优化所述学生模型的无标签点集的软标签的预测,直至所述学生模型得到的第二预测结果与所述教师模型得到的第一预测结果最相近,得到训练好的学生模型;
其中,K为所述学生模型的总层数,min为最小化,Θ为参数集合,fGNN(v)为节点v的教师模型软标签,
Figure QLYQS_28
为节点v的第k层CPF模型,CPF;Θ为模型及其参数,;为分隔符,‖.‖2为L2范数。
8.如权利要求5至7任一项所述的装置,其特征在于,所述装置还包括:预测模块,用于:
获取数据集,以及图结构中的节点集和边集;
通过所述图结构及所述数据集预训练,得到图神经网络教师模型;所述教师模型为图神经网络GNN分类器;
采用所述教师模型预测无标签点集的软标签,可能为所述硬标签的第一概率,得到所述无标签点集的软标签的第一预测结果。
CN202110102108.1A 2021-01-26 2021-01-26 一种基于图神经网络知识蒸馏的图节点分类方法及装置 Active CN112861936B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110102108.1A CN112861936B (zh) 2021-01-26 2021-01-26 一种基于图神经网络知识蒸馏的图节点分类方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110102108.1A CN112861936B (zh) 2021-01-26 2021-01-26 一种基于图神经网络知识蒸馏的图节点分类方法及装置

Publications (2)

Publication Number Publication Date
CN112861936A CN112861936A (zh) 2021-05-28
CN112861936B true CN112861936B (zh) 2023-06-02

Family

ID=76009124

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110102108.1A Active CN112861936B (zh) 2021-01-26 2021-01-26 一种基于图神经网络知识蒸馏的图节点分类方法及装置

Country Status (1)

Country Link
CN (1) CN112861936B (zh)

Families Citing this family (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113343113A (zh) * 2021-07-05 2021-09-03 合肥工业大学 基于图卷积网络进行知识蒸馏的冷启动实体推荐方法
CN113610173B (zh) * 2021-08-13 2022-10-04 天津大学 一种基于知识蒸馏的多跨域少样本分类方法
CN113627545B (zh) * 2021-08-16 2023-08-08 山东大学 一种基于同构多教师指导知识蒸馏的图像分类方法及系统
CN113887698B (zh) * 2021-08-25 2024-06-14 浙江大学 基于图神经网络的整体知识蒸馏方法和系统
CN114092747A (zh) * 2021-11-30 2022-02-25 南通大学 基于深度元度量模型互学习的小样本图像分类方法
CN115761654B (zh) * 2022-11-11 2023-11-24 中南大学 一种车辆重识别方法
CN117237343B (zh) * 2023-11-13 2024-01-30 安徽大学 半监督rgb-d图像镜面检测方法、存储介质及计算机设备
CN118503435B (zh) * 2024-07-22 2024-10-11 浙江大学 基于知识融合的多未知领域文本分类方法、设备、介质

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN111462137A (zh) * 2020-04-02 2020-07-28 中科人工智能创新技术研究院(青岛)有限公司 一种基于知识蒸馏和语义融合的点云场景分割方法
CN112183670A (zh) * 2020-11-05 2021-01-05 南开大学 一种基于知识蒸馏的少样本虚假新闻检测方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11741355B2 (en) * 2018-07-27 2023-08-29 International Business Machines Corporation Training of student neural network with teacher neural networks

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN111462137A (zh) * 2020-04-02 2020-07-28 中科人工智能创新技术研究院(青岛)有限公司 一种基于知识蒸馏和语义融合的点云场景分割方法
CN112183670A (zh) * 2020-11-05 2021-01-05 南开大学 一种基于知识蒸馏的少样本虚假新闻检测方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
Zhang Mengmei 等.Adversarial Label-Flipping Attack and Defense for Graph Neural Networks.2020 IEEE International Conference on Data Mining (ICDM).2020,791-800. *

Also Published As

Publication number Publication date
CN112861936A (zh) 2021-05-28

Similar Documents

Publication Publication Date Title
CN112861936B (zh) 一种基于图神经网络知识蒸馏的图节点分类方法及装置
Liang et al. Text feature extraction based on deep learning: a review
Kim et al. Deep hybrid recommender systems via exploiting document context and statistics of items
CN110263227B (zh) 基于图神经网络的团伙发现方法和系统
CN104915386B (zh) 一种基于深度语义特征学习的短文本聚类方法
Yang et al. Variational co-embedding learning for attributed network clustering
Kumar et al. Multi-label classification using hierarchical embedding
CN112749274A (zh) 基于注意力机制和干扰词删除的中文文本分类方法
CN114741507B (zh) 基于Transformer的图卷积网络的引文网络分类模型建立及分类
CN112257841A (zh) 图神经网络中的数据处理方法、装置、设备及存储介质
Chauhan et al. Randomized neural networks for multilabel classification
Yu et al. Can machine learning paradigm improve attribute noise problem in credit risk classification?
CN113449853A (zh) 一种图卷积神经网络模型及其训练方法
Yu et al. PKGCN: prior knowledge enhanced graph convolutional network for graph-based semi-supervised learning
CN117349494A (zh) 空间图卷积神经网络的图分类方法、系统、介质及设备
Feng et al. Ontology semantic integration based on convolutional neural network
La Rosa et al. A self-interpretable module for deep image classification on small data
Hong et al. ProtoryNet-interpretable text classification via prototype trajectories
Xia et al. HatchEnsemble: an efficient and practical uncertainty quantification method for deep neural networks
Zhang et al. Dep-tsp meta: A multiple criteria dynamic ensemble pruning technique ad-hoc for time series prediction
Yilmaz Connectionist-symbolic machine intelligence using cellular automata based reservoir-hyperdimensional computing
Chen et al. Dropout training for SVMs with data augmentation
Nguyen et al. Graph-induced restricted Boltzmann machines for document modeling
Louati et al. Embedding channel pruning within the CNN architecture design using a bi-level evolutionary approach
Jia et al. An optimized classification algorithm by neural network ensemble based on PLS and OLS

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant