CN114707644A - 图神经网络的训练方法及装置 - Google Patents

图神经网络的训练方法及装置 Download PDF

Info

Publication number
CN114707644A
CN114707644A CN202210440602.3A CN202210440602A CN114707644A CN 114707644 A CN114707644 A CN 114707644A CN 202210440602 A CN202210440602 A CN 202210440602A CN 114707644 A CN114707644 A CN 114707644A
Authority
CN
China
Prior art keywords
node
neural network
nodes
classified
unmarked
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210440602.3A
Other languages
English (en)
Other versions
CN114707644B (zh
Inventor
胡斌斌
刘洪瑞
张志强
石川
王啸
周俊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Posts and Telecommunications
Alipay Hangzhou Information Technology Co Ltd
Original Assignee
Beijing University of Posts and Telecommunications
Alipay Hangzhou Information Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Posts and Telecommunications, Alipay Hangzhou Information Technology Co Ltd filed Critical Beijing University of Posts and Telecommunications
Priority to CN202210440602.3A priority Critical patent/CN114707644B/zh
Publication of CN114707644A publication Critical patent/CN114707644A/zh
Priority to US18/306,144 priority patent/US20230342606A1/en
Application granted granted Critical
Publication of CN114707644B publication Critical patent/CN114707644B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/20Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
    • G06F16/28Databases characterised by their database models, e.g. relational or object models
    • G06F16/284Relational databases
    • G06F16/288Entity relationship models
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • G06F18/24155Bayesian classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/0895Weakly supervised learning, e.g. semi-supervised or self-supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation
    • G06N5/022Knowledge engineering; Knowledge acquisition
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Databases & Information Systems (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本说明书实施例提供一种图神经网络的训练方法,涉及基于用户关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。

Description

图神经网络的训练方法及装置
技术领域
本说明书一个或多个实施例涉及机器学习技术领域,尤其涉及一种图神经网络的训练方法及装置。
背景技术
关系网络图是对现实世界中实体之间的关系的描述,目前被广泛应用于各种业务处理中,如社交网络分析、化学键预测等。图神经网络(Graph Neural Networks,简称GNN)适用于处理关系网络图上的各种任务,然而,GNN的性能在很大程度上依赖标注数据的数量,通常,GNN的性能会随着标注数据的减少而迅速下降。
因此,需要一种方案,能够突破GNN训练时标注数据不足的限制,得到性能优异的GNN模型,从而有效提升业务处理结果的准确度。
发明内容
本说明书一个或多个实施例描述了一种图神经网络的训练方法及装置,利用未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果。
根据第一方面,提供一种图神经网络的训练方法,涉及基于用户关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
在一个实施例中,所述多个用户节点中包括第二数量的未标注节点,各个分类预测向量中包括与多个类别对应的多个预测概率;其中,基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签,包括:针对所述第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入所述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。
在一个实施例中,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益,包括:针对任意的第一未标注节点,利用其对应的第一分类预测向量和伪分类标签,训练所述当前图神经网络,并基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量;根据所述第一分类预测向量,确定第一信息熵;根据所述第二分类预测向量,确定第二信息熵;基于所述第二信息熵与所述第一信息熵的差值,得到所述信息增益。
在一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:在所述多个聚合层中的某个聚合层,对上一聚合层输出的针对所述多个用户节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个用户节点输出的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
在另一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:在所述多个聚合层中的某个聚合层,对所述用户关系图谱所对应邻接矩阵中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对所述多个用户节点的多个聚合向量,确定本聚合层针对所述多个用户节点的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
进一步,在一个更具体的实施例中,基于训练出的第一图神经网络确定该未标注节点的第二分类预测向量,包括:多次执行确定所述第二分类预测向量的操作,对应得到多个第二分类预测向量;其中,根据所述第二分类预测向量,确定第二信息熵,包括:将所述多个第二分类预测向量所对应多个信息熵的均值,确定为所述第二信息熵。
在一个实施例中,根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数,包括:根据所述各个标注节点对应的分类预测向量和真实分类标签,确定第一损失项;针对所述各个未标注节点,根据其对应的分类预测向量和伪分类标签,确定第二损失项,并利用其对应的信息增益对所述第二损失项进行加权处理;根据所述第一损失项和加权处理后的第二损失项,更新所述模型参数。
在一个具体的实施例中,利用其对应的信息增益对所述第二损失项进行加权处理,包括:利用所述第一数量的未标注节点所对应第一数量的信息增益,对所述各个未标注节点的信息增益进行归一化处理,得到对应的加权系数;利用所述加权系数进行所述加权处理。
根据第二方面,提供一种图神经网络的训练方法,涉及基于预先构建的关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
根据第三方面,提供一种图神经网络的训练装置,所述装置通过以下单元,根据用户关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:分类预测单元,配置为利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;伪标签分配单元,配置为基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;信息增益确定单元,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;参数更新单元,配置为根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
根据第四方面,提供一种图神经网络的训练装置,所述装置通过以下单元,根据预先构建的关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:分类预测单元,配置为利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;伪标签分配单元,配置为基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签;信息增益确定单元,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;参数更新单元,配置为根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
根据第五方面,提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行第一方面或第二方面的方法。
根据第六方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,该处理器执行所述可执行代码时,实现第一方面或第二方面的方法。
采用本说明书实施例提供的方法和装置,利用用户关系图谱中的未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果,进而提高训练出的GNN模型对用户节点的预测准确度。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1示出根据一个实施例的图神经网络的训练框架示意图;
图2示出根据一个实施例的图神经网络的训练方法流程示意图;
图3示出根据一个实施例的确定信息增益的方法流程图;
图4示出根据另一个实施例的图神经网络的训练方法流程示意图;
图5示出根据一个实施例的图神经网络的训练装置结构示意图;
图6示出根据另一个实施例的图神经网络的训练装置结构示意图。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
承前所述,需要一种能够突破GNN训练时标注数据不足的限制的方案。由此,提出一种思路,采用自训练(self-training)的方式,通过充分利用丰富的未标注数据来解决标注数据的稀缺性问题。具体,给定一个在原始的标注数据集
Figure BDA0003614932360000041
上训练出的模型作为教师模型,对未标注数据集
Figure BDA0003614932360000042
进行预测,接着,利用其中高置信度的预测结果为对应的未标注数据子集
Figure BDA0003614932360000043
打上伪标签,实现对原始标注数据的扩充,然后,利用扩充后的标注数据集
Figure BDA0003614932360000045
Figure BDA0003614932360000044
训练学生模型,并利用训练得到的学生模型更新教师模型,如此往复迭代,直到学生模型收敛。
上述自训练方式的关键在于,对高置信度的未标注样本进行伪标注以实现对标注数据的扩充。然而,发明人通过实验和分析发现,利用高置信度的未标注样本进行扩充而得到的扩充后标注数据集
Figure BDA0003614932360000051
相较原始的标注数据集
Figure BDA00036149323600000510
发生了分布迁移(distributionshift),导致利用扩充数据集
Figure BDA0003614932360000052
训练出的GNN模型的性能欠佳,难以得到足够清晰、鲁邦的决策边界。进一步,从损失函数的角度进行分析,将原始标注数据集
Figure BDA0003614932360000053
服从的数据分布记作Ppop,给定一个参数记为θ的分类器fθ,于是,可以通过最小化下式(1)示意的损失函数,得到模型参数θ的最优设定。
Figure BDA0003614932360000054
在上式(1)中,vi和yi分别表示服从Ppop分布的第i个标注节点的节点特征和节点标签;pi表示分类器fθ针对第i个标注节点输出的预测结果;l(·,·)表示多分类损失,例如,可以是交叉熵损失。
类似地,对于上述存在分布迁移的自训练场景,可以采用下示损失函数计算训练损失:
Figure BDA0003614932360000055
上式中,vu和yu分别表示服从Pst分布的第u个未标注节点的节点特征和真实节点标签(实际未获取到);
Figure BDA00036149323600000511
表示第u个未标注节点的伪标签;pu表示分类器fθ针对第u个未标注节点输出的预测结果。
通过对上述公式(1)和(2)进行对比分析可以得出,自训练过程中的分布迁移会严重影响图模型的训练性能,进而导致图模型在预测阶段泛化性能的恶化。因此,相较采用公式(2),采用公式(1)计算出的训练损失去优化分类器fθ更为理想。然而,在实际应用中,因标注数据稀少而难以准确还原真实的标注数据分布,只有公式(2)中计算出的Lst是可用的。基于此,为了减小甚至消除Lst和Lpop之间的差距,发明人提出以下定理:
给定公式(1)和(2)中分别定义的损失Lpop和Lst,假定对于伪标注数据集
Figure BDA00036149323600000512
中的每个节点vu都存在
Figure BDA0003614932360000056
那么Lst=Lpop成立,当Lst可以被记作含有额外的权重系数
Figure BDA0003614932360000057
Figure BDA0003614932360000058
的下式:
Figure BDA0003614932360000059
上述定理的证明过程如下:
首先,根据针对节点vu的假定:
Figure BDA0003614932360000061
可以将公式(1)重写为以下形式:
Figure BDA0003614932360000062
注意到:
Figure BDA0003614932360000063
由此,可以将公式(4)重写为:
Figure BDA0003614932360000064
其中,γu可以被当作未标注节点vu的损失函数
Figure BDA0003614932360000065
的权重。
最后,回顾公式(2)中示出的分布迁移情况下的损失函数,可以发现,公式(2)中的Lst可以被记作在公式(1)中添加额外的权重系数γu的形式。换言之,只要能够在Lst中为每个伪标注节点添加一个合适的系数γu,就能够使得Lst逼近Lpop
然而,因为标注数据分布Ppop通常是难求解的,这就意味着权重系数γu难以被精确求解。进一步,发明人采用可视化等手段发现,未标注节点vu对应的权重系数γu和信息增益
Figure BDA0003614932360000069
具有相同的变化趋势,具体地,离决策边界越远,二者的取值越小,因此,提出通过求解信息增益
Figure BDA00036149323600000610
去近似权重系数γu。简单来说,信息增益
Figure BDA00036149323600000611
是对未标注节点vu针对模型优化贡献的衡量。
由此,发明人提出另一种实施方式,通过引入信息增益
Figure BDA00036149323600000612
,对自训练损失Lst中针对未标注节点vu的损失项
Figure BDA0003614932360000066
进行加权,以使得Lst逼近或等于Lpop。为便于直观理解,图1示出根据一个实施例的图神经网络的训练框架示意图。如图1所示,用户关系网络图的图数据中,包括原始的标注样本集
Figure BDA0003614932360000067
和未标注样本集
Figure BDA0003614932360000068
基于此,在任一轮次的图神经网络迭代训练中,先利用当前的GNN模型处理关系网络图,得到未标注样本集
Figure BDA0003614932360000071
中每个未标注节点的分类结果和置信度,选取置信度足够高的未标注节点加入未标注样本子集
Figure BDA0003614932360000072
并确定其中各个未标注节点vu的信息增益,进而利用标注样本集
Figure BDA0003614932360000073
以及未标注样本子集
Figure BDA0003614932360000074
和对应确定出的信息增益,确定训练损失,用以更新GNN模型。
下面结合更多实施例,描述实现上述发明构思的方案实施步骤。图2示出根据一个实施例的图神经网络的训练方法流程示意图,所述方法的执行主体可以为任何具有计算、处理能力的装置、平台或设备集群。
图2示出的训练方法涉及基于用户关系图谱(或称用户关系网络图),对图神经网络进行多轮次迭代更新。为便于理解,先对用户关系图谱进行基本介绍。
用户关系图谱中包括与多个用户对应的多个用户节点,以及用户节点之间存在关联关系而形成的连接边。用户节点的节点特征可以包括对应用户的静态特征(或称基础属性特征)和行为特征。在一个实施例中,用户静态特征可以包括用户性别、年龄、职业、常驻地、兴趣爱好等。在一个实施例中,用户行为特征可以包括消费频次、消费金额、消费时段、消费类别、社交网站发布的图文内容和社交活跃度,等等。
上述多个用户节点包含小部分携带用户类别标签的标注节点,以及大量没有携带标签的未标注节点。通常,其中的标注节点携带的标签是耗费昂贵的人工成本进行打标而得到。用户类别标签与具体的预测任务相适应,在一个实施例中,预测任务是用户风险评估,相应,用户类别标签可以包括有风险用户和无风险用户,或者,包括高风险用户、低风险用户和中风险用户,又或者,包括违约用户和守信用户,再或者,包括欺诈用户和安全用户。在另一个实施例中,预测任务是消费人群划分,相应,用户类别标签可以包括高消费人群和低消费人群。
以上,对用户关系图谱进行基本介绍。如图2所示,上述多轮次迭代更新中任意一轮次的迭代更新包括以下步骤:
步骤S210,利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;步骤S220,基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;步骤S230,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;步骤S240,根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
对以上步骤的展开介绍如下:
首先,在步骤S210,利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量。在一个实施例中,本轮迭代为首轮,相应,当前图神经网络可以是进行参数初始化设定后的图神经网络,或者,可以是利用多个标注节点及其携带的标签,对参数初始化后的图神经网络进行训练后得到的图神经网络。在另一个实施例中,本轮迭代非首轮,相应,当前图神经网络可以是经由上一轮迭代更新后得到的图神经网络。
当前图神经网络中包括多个聚合层和输出层,其中多个聚合层用于对用户关系图谱进行图嵌入处理,得到对应上述多个用户节点的多个节点嵌入向量。需理解,多个聚合层中首个聚合层的输入包括用户节点和/或连接边的原始特征,多个聚合层基于该原始特征进行节点的高阶表征,从而得到具有深度语义的节点表征向量(或称节点嵌入向量)。进一步,输出层用于根据各个节点嵌入向量,输出对应用户节点的分类预测结果。
在一个实施例中,当前图神经网络的类型为图卷积神经网络(GraphConvolutional Network,简称GCN),相应,GCN中任意的第l个聚合层的输出H(l)可以通过下式计算:
Figure BDA0003614932360000081
在上式(7)中,A表示用户关系图谱的邻接矩阵,用于记录用户节点之间的连接关系,示例性地,对于邻阶矩阵A中任意的元素Aij,当其数值为1或0时,分别表示用户节点i和用户节点j之间存在和不存在连接边;
Figure BDA0003614932360000082
表示归一化算子;W(l)表示第l个聚合层中的参数矩阵,且
Figure BDA0003614932360000083
表示第(l-1)个聚合层的输出。需理解,
Figure BDA0003614932360000084
其中X表示上述多个用户节点的节点特征形成的特征矩阵,
Figure BDA0003614932360000088
表示上述多个用户节点形成的节点集合,
Figure BDA0003614932360000085
表示节点集合中节点的个数,Dv表示节点特征的维数。另外,
Figure BDA0003614932360000086
GCN模型的参数
Figure BDA0003614932360000087
L为多个聚合层的总层数。
在另一个实施例中,当前图神经网络的类型还可以为图注意力网络(GraphAttention Network,简称GAT)等,需理解,图神经网络的已有类型繁多,在本说明书披露的实施例中可以按需选取,具体不作限定。
另一方面,上述输出层包括一个或多个全连接网络子层,利用此全连接网络子层,可以对对各个节点嵌入向量分别进行线性变换和/或非线性变换处理,从而得到对应用户节点的分类预测向量,该分类预测向量中的多个向量元素对应多个类别概率。
由上,可以得到多个用户节点对应的多个分类预测向量。接着,在步骤S220,基于该多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签。
为区分描述,将多个用户节点中所有未标注节点的数量记作第二数量。具体,在本步骤中,可以基于第二数量的未标注节点对应的分类预测向量,为其中的部分或全部未标注节点分配对应的伪分类标签。
在一个实施例中,针对第二数量的未标注节点中的各个节点,将其所对应分类预测向量中最大预测概率所对应的类别确定为该节点的伪分类标签。如此,可以实现为第二数量的未标注节点分配对应的伪分类标签,此时,上述第一数量等于第二数量。
在另一个实施例中,针对第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入上述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。在一个具体的实施例中,其中预设标准为:最大预测概率大于预设阈值(例如,0.2)。在另一个具体的实施例中,其中预设标准为:该节点对应的最大预测概率在第二数量的最大预测概率中排在前k(例如,k=1000)位。如此,可以实现从全量未标注节点中选取高置信度(置信度等于最大预测概率)的未标注节点,并为其打上伪标签,此时,第一数量小于第二数量。
以上,可以实现对第一数量的未标注节点的自动打标。为清楚描述,本说明书实施例中将第一数量的未标注节点形成的未标注子集记作
Figure BDA0003614932360000091
然后,在步骤S230,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练当前图神经网络而产生的信息增益。需理解,在概率论或信息论中,信息增益是指在为随机事件(例如,明天是否会下雨)中的某个随机变量(例如,明天的天气)赋予具体的变量值(例如,阴天)后,该随机事件信息量的减小量。其中信息量通常是计算香农熵(Shannon’s entropy),或称信息熵而得到。根据信息增益的定义,对于未标注子集
Figure BDA0003614932360000092
中任一的未标注节点vu,可以利用通过预测分布(predictive distribution)和后验参数
Figure BDA0003614932360000093
计算其对GNN模型参数θ的信息增益
Figure BDA0003614932360000094
具体可参见下式:
Figure BDA0003614932360000095
在上式中,右侧第1项是在后验参数
Figure BDA0003614932360000096
下预测分布
Figure BDA0003614932360000097
的信息熵的期望值,用来度量模型参数θ没有发生改变时的信息量,
Figure BDA0003614932360000098
表示上述用户关系图谱,yu表示GNN模型fθ输出的类别概率向量;第2项是给定节点特征xu下条件熵的平均值(或称期望值),用来捕捉利用节点vu优化模型fθ之后模型参数θ的信息量。如此,通过计算这两项之间的差值,可以度量未标注节点vu为模型参数θ带来的信息增益。
观察公式(8)可知,若采用其计算信息增益
Figure BDA00036149323600000910
,则需要计算后验参数
Figure BDA0003614932360000099
然而,后验参数
Figure BDA0003614932360000101
通常是难以求解的,在一种可能的方式中,可以采用传统贝叶斯网络计算后验参数
Figure BDA0003614932360000102
但是,这将带来巨大的计算消耗。
由此提出另一种方式,通过少量计算即可获得较为精确的信息增益计算值。具体,采用dropout或dropedge算法,实现对后验参数
Figure BDA0003614932360000103
的近似。下面,结合图3,对基于dropout算法或dropedge算法确定信息增益的方式进行介绍。如图3所示,包括实现以下步骤:
步骤S31,针对未标注节点子集
Figure BDA0003614932360000104
中任意的未标注节点vu(或称第一未标注节点),利用上述多个分类预测向量中与之对应的第一分类预测向量,以及其对应的伪分类标签,训练当前图神经网络,得到第一图神经网络。具体地,利用第一分类预测向量和对应的伪分类标签计算训练损失,再利用此训练损失优化(或称更新)当前图神经网络中的参数,得到更新后的第一图神经网络。
步骤S32,基于训练出的第一图神经网络确定该未标注节点vu的第二分类预测向量。
在一个实施例中,引入dropout算法,对用户节点特征进行随机屏蔽(或称随机置零)。具体,在第一图神经网络所包含多个聚合层中的某个聚合层,对上一聚合层输出的针对多个用户节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个用户节点输出的多个聚合向量。
在一个具体的实施例中,上述某个聚合层可以由工作人员预先指定或随机设定,例如,可以指定在多个聚合层中的最后一个聚合层进行dropout操作。在一个具体的实施例中,对向量元素执行置零处理的聚合层不限于1个,还可以为其他数量,例如,可以在各个聚合层中均执行节点特征的dropout操作。
另一方面,在一个具体的实施例中,在当前图神经网络为GCN为的情况下,可以将上述某个聚合层中基于dropout算法对上一层输出的多个聚合向量进行处理以得到本层输出的过程记作下式:
Figure BDA0003614932360000105
在上式中,H(l-1)表示上一层输出的多个聚合向量形成的矩阵;
Figure BDA0003614932360000106
其中的Dl-1×Dl个矩阵元素可以通过从伯努利分布进行多次采样而得到,各个矩阵元素指示是否将矩阵H(l-1)中对应位置的矩阵元素置零;运算符⊙表示对两个矩阵之间具有相同位置的元素进行相乘运算。
进一步,在输出层,对最后一个聚合层针对未标注节点vu输出的聚合向量H(L)进行处理,得到第二分类预测向量
Figure BDA0003614932360000107
或者,也可以对多个聚合层针对未标注节点vu输出的多个聚合向量
Figure BDA0003614932360000108
进行平均处理,得到第二分类预测向量
Figure BDA0003614932360000109
注意到,公式(9)中的运算项H(l-1)⊙Z(l-1)实现在节点特征上的伯努利采样,相当于从后验参数
Figure BDA0003614932360000112
所符合的参数分布中进行采样。因此,为估计后验参数
Figure BDA0003614932360000113
可以多次执行上述预测操作以对该参数分布进行多次(记作T次)采样,相应,每次(记作第t次)采样都会得到对应的第二分类预测向量
Figure BDA0003614932360000114
如此,可以基于dropout算法得到T个第二分类预测向量
Figure BDA0003614932360000115
在另一个实施例中,引入dropedge算法,对用户节点之间的连接边进行随机屏蔽。具体,在第一图神经网络所包含多个聚合层中的某个聚合层,对用户关系图谱所对应邻接矩阵A中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对上述多个用户节点的多个聚合向量,确定本聚合层针对该多个用户节点的多个聚合向量。
在一个具体的实施例中,上述某个聚合层可以由工作人员预先指定或随机设定,实际应用中,可以将上述某个聚合层指定为多个聚合层中的最后一个聚合层。在一个具体的实施例中,对邻接矩阵元素执行置零处理的聚合层不限于1个,还可以为其他数量,例如,可以在各个聚合层中均执行边特征的dropedge操作。
另一方面,在一个具体的实施例中,在当前图神经网络为GCN为的情况下,可以将上述某个聚合层中基于dropedge算法对上一层输出的多个聚合向量进行处理以得到本层输出的过程记作下式:
Figure BDA0003614932360000116
在上式中,H(l-1)表示上一层输出的多个聚合向量形成的矩阵;
Figure BDA0003614932360000117
其中的
Figure BDA0003614932360000118
个矩阵元素可以通过从伯努利分布进行多次采样而得到,各个矩阵元素指示是否将邻接矩阵A中对应位置的矩阵元素置零。
进一步,在输出层,对最后一个聚合层针对未标注节点vu输出的聚合向量H(L)进行处理,得到第二分类预测向量
Figure BDA0003614932360000119
或者,也可以对多个聚合层针对未标注节点vu输出的多个聚合向量
Figure BDA00036149323600001110
进行平均处理,得到第二分类预测向量
Figure BDA00036149323600001111
注意到,公式(10)中的运算项A⊙Z(l)实现在连接边上的伯努利采样,相当于从后验参数
Figure BDA00036149323600001113
所符合的参数分布中进行采样。因此,为估计后验参数
Figure BDA00036149323600001114
可以多次执行上述预测操作以对该参数分布进行多次(记作T次)采样,相应,每次(记作第t次)采样都会得到对应的第二分类预测向量
Figure BDA00036149323600001115
如此,可以基于dropedge算法得到T个第二分类预测向量
Figure BDA00036149323600001116
由上,可以基于dropout算法或dropedge算法得到T个第二分类预测向量
Figure BDA00036149323600001117
步骤S33,利用基于第一分类预测向量而确定的第一信息熵,减去基于第二分类预测向量而确定的第二信息熵,得到利用第一未标注节点训练当前图神经网络的信息增益。
在一个实施例中,可以对得到的上述T个第二分类预测向量
Figure BDA0003614932360000121
进行求平均,以得到针对未标注节点vu的预测向量的期望:
Figure BDA0003614932360000122
由此,可以采用下式计算与未标注节点vu对应的信息增益
Figure BDA0003614932360000123
Figure BDA0003614932360000124
在上式中,右侧的第1项表示上述第一信息熵,第2项的相反数表示上述第二信息熵。具体,D表示分类预测向量的维数,也即类别总数;
Figure BDA0003614932360000125
表示第一分类预测向量中与第d个类别对应的预测概率;
Figure BDA0003614932360000126
表示第t个第二分类预测向量中与第d个类别对应的预测概率。
由上,可以实现基于dropout算法或dropedge算法,确定未标注节点vu为模型参数带来的信息增益
Figure BDA0003614932360000127
另一方面,不那么优选地,在一个实施例中,在上述步骤S32中,可以不引入dropout或dropedge算法,而是直接利用第一图神经网络中的未被置零处理的参数处理用户关系图谱,得到未标注节点vu的第二分类预测向量,从而根据此第二分类预测向量计算第二信息熵;在另一个实施例中,在上述步骤S32中,公式(12)中的参数采样次数T也可以取1。
由上,可以确定未标注子集
Figure BDA0003614932360000128
中各个未标注节点vu可以为当前GNN模型参数带来的信息增益
Figure BDA0003614932360000129
之后,在步骤S240,根据与上述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
具体,一方面,根据上述各个标注节点对应的分类预测向量和真实分类标签,确定第一损失项;另一方面,针对上述各个未标注节点,根据其对应的分类预测向量和伪分类标签,确定第二损失项,并利用其对应的信息增益对所述第二损失项进行加权处理。进一步地,根据该第一损失项和加权处理后的第二损失项确定综合损失,从而根据此综合损失更新当前图神经网络中的模型参数。
在一个实施例中,上述加权处理包括:利用上述第一数量的未标注节点所对应第一数量的信息增益,对上述各个未标注节点的信息增益进行归一化处理,得到对应的加权系数;利用此加权系数进行针对第二损失项的加权处理。
根据一个示例,可以采用下式计算上述综合损失:
Figure BDA0003614932360000131
在上式(13)中,
Figure BDA0003614932360000132
表示未标注子集
Figure BDA0003614932360000133
中第i个节点的信息增益。如此,可以实现利用信息增益
Figure BDA0003614932360000134
的归一化结果
Figure BDA0003614932360000135
近似公式(3)的权重系数γu,从而得到逼近损失Lpop的Lst
进一步,可以利用确定出的综合损失计算训练梯度,进而根据此训练梯度,采用反向传播法更新当前图神经网络模型中的模型参数。
综上,采用本说明书实施例披露的图神经网络的训练方法,利用用户关系图谱中的未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果,进而提高训练出的GNN模型对用户节点的预测准确度。
以上,对训练用于处理用户关系网络图的图神经网络的方法进行介绍,实际,上述方法还可以拓展到训练关联其他业务对象的关系网络图的图神经网络。图4示出根据另一个实施例的图神经网络的训练方法流程示意图,所述方法的执行主体可以是任何具有计算、处理能力的装置、服务器或设备集群。
图2示出的训练方法涉及基于关系图谱对图神经网络进行多轮次迭代更新,该关系图谱中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边。在一个实施例中,该多个业务对象为多个商品,进一步,商品节点的特征可以包括:类别、产地、成本、售价等,商品节点涉及的标签可以是商品热门等级,如热门商品或冷门商品。在另一个实施例中,该多个业务对象为多篇论文,进一步,论文节点特征可以包括:论文名称、关键字、摘要等,论文节点涉及的标签可以为论文所属领域,如生物、化学、物理、计算机等。
如图4所示,所述方法包括以下步骤:
步骤S410,利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;步骤S420,基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签;步骤S430,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;步骤S440,根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
需说明,对图4示出的方法步骤的描述,可以参见前述实施例中对图2示出的方法步骤的描述,在此不作赘述。
综上,采用本说明书实施例披露的图神经网络的训练方法,利用关系图谱中的未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果,进而提高训练出的GNN模型对业务对象节点的预测准确度。
与上述训练方法相对应的,本说明书实施例还披露训练装置。图5示出根据一个实施例的图神经网络的训练装置结构示意图,所述装置500通过以下单元,根据用户关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:
分类预测单元510,配置为利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;伪标签分配单元520,配置为基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;信息增益确定单元530,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;参数更新单元540,配置为根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
在一个实施例中,所述多个用户节点中包括第二数量的未标注节点,各个分类预测向量中包括与多个类别对应的多个预测概率;所述伪标签分配单元520具体配置为:针对所述第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入所述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。
在一个实施例中,所述信息增益确定单元530包括:训练子单元531,配置为针对任意的第一未标注节点,利用其对应的第一分类预测向量和伪分类标签,训练所述当前图神经网络;预测子单元532,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量;信息熵确定子单元533,配置为根据所述第一分类预测向量,确定第一信息熵,以及根据所述第二分类预测向量,确定第二信息熵;增益确定子单元534,配置为基于所述第二信息熵与所述第一信息熵的差值,得到所述信息增益。
进一步,在一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;所述预测子单元532具体配置为:在所述多个聚合层中的某个聚合层,对上一聚合层输出的针对所述多个用户节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个用户节点输出的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
在另一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;所述预测子单元532具体配置为:在所述多个聚合层中的某个聚合层,对所述用户关系图谱所对应邻接矩阵中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对所述多个用户节点的多个聚合向量,确定本聚合层针对所述多个用户节点的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
进一步,在一个更具体的实施例中,所述预测子单元532进一步配置为:多次执行确定所述第二分类预测向量的操作,对应得到多个第二分类预测向量;所述信息熵确定子单元533具体配置为:将所述多个第二分类预测向量所对应多个信息熵的均值,确定为所述第二信息熵。
在一个实施例中,参数更新单元540配置为:根据所述各个标注节点对应的分类预测向量和真实分类标签,确定第一损失项;针对所述各个未标注节点,根据其对应的分类预测向量和伪分类标签,确定第二损失项,并利用其对应的信息增益对所述第二损失项进行加权处理;根据所述第一损失项和加权处理后的第二损失项,更新所述模型参数。
在一个具体的实施例中,参数更新单元540配置为进行上述加权处理,具体包括:利用所述第一数量的未标注节点所对应第一数量的信息增益,对所述各个未标注节点的信息增益进行归一化处理,得到对应的加权系数;利用所述加权系数进行所述加权处理。
综上,采用本说明书实施例披露的图神经网络的训练装置,利用用户关系图谱中的未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果,进而提高训练出的GNN模型对用户节点的预测准确度。
图6示出根据另一个实施例的图神经网络的训练装置结构示意图,如图6所示,所述装置600通过以下单元,根据预先构建的关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:
分类预测单元610,配置为利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量。伪标签分配单元620,配置为基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签。信息增益确定单元630,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益。参数更新单元640,配置为根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
在一个实施例中,所述多个业务对象节点中包括第二数量的未标注节点,各个分类预测向量中包括与多个类别对应的多个预测概率;所述伪标签分配单元620具体配置为:针对所述第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入所述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。
在一个实施例中,所述信息增益确定单元630包括:训练子单元631,配置为针对任意的第一未标注节点,利用其对应的第一分类预测向量和伪分类标签,训练所述当前图神经网络;预测子单元632,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量;信息熵确定子单元633,配置为根据所述第一分类预测向量,确定第一信息熵,以及根据所述第二分类预测向量,确定第二信息熵;增益确定子单元634,配置为基于所述第二信息熵与所述第一信息熵的差值,得到所述信息增益。
进一步,在一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;所述预测子单元632具体配置为:在所述多个聚合层中的某个聚合层,对上一聚合层输出的针对所述多个业务对象节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个业务对象节点输出的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注业务对象节点输出的聚合向量进行处理,得到所述第二分类预测向量。
在另一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;所述预测子单元632具体配置为:在所述多个聚合层中的某个聚合层,对所述业务对象关系图谱所对应邻接矩阵中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对所述多个业务对象节点的多个聚合向量,确定本聚合层针对所述多个业务对象节点的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注业务对象节点输出的聚合向量进行处理,得到所述第二分类预测向量。
进一步,在一个更具体的实施例中,所述预测子单元632进一步配置为:多次执行确定所述第二分类预测向量的操作,对应得到多个第二分类预测向量;所述信息熵确定子单元633具体配置为:将所述多个第二分类预测向量所对应多个信息熵的均值,确定为所述第二信息熵。
在一个实施例中,参数更新单元640配置为:根据所述各个标注节点对应的分类预测向量和真实分类标签,确定第一损失项;针对所述各个未标注节点,根据其对应的分类预测向量和伪分类标签,确定第二损失项,并利用其对应的信息增益对所述第二损失项进行加权处理;根据所述第一损失项和加权处理后的第二损失项,更新所述模型参数。
在一个具体的实施例中,参数更新单元640配置为进行上述加权处理,具体包括:利用所述第一数量的未标注节点所对应第一数量的信息增益,对所述各个未标注节点的信息增益进行归一化处理,得到对应的加权系数;利用所述加权系数进行所述加权处理。
综上,采用本说明书实施例披露的图神经网络的训练装置,利用业务对象关系图谱中的未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果,进而提高训练出的GNN模型对业务对象节点的预测准确度。
根据另一方面的实施例,还提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行结合图2或图3所描述的方法。
根据再一方面的实施例,还提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现结合图2或图3所描述的方法。本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。

Claims (13)

1.一种图神经网络的训练方法,涉及基于用户关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:
利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;
基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;
针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;
根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
2.根据权利要求1所述的方法,其中,所述多个用户节点中包括第二数量的未标注节点,各个分类预测向量中包括与多个类别对应的多个预测概率;
其中,基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签,包括:
针对所述第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入所述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。
3.根据权利要求1所述的方法,其中,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益,包括:
针对任意的第一未标注节点,利用其对应的第一分类预测向量和伪分类标签,训练所述当前图神经网络,并基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量;
根据所述第一分类预测向量,确定第一信息熵;
根据所述第二分类预测向量,确定第二信息熵;
基于所述第二信息熵与所述第一信息熵的差值,得到所述信息增益。
4.根据权利要求3所述的方法,其中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:
在所述多个聚合层中的某个聚合层,对上一聚合层输出的针对所述多个用户节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个用户节点输出的多个聚合向量;
在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
5.根据权利要求3所述的方法,其中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:
在所述多个聚合层中的某个聚合层,对所述用户关系图谱所对应邻接矩阵中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对所述多个用户节点的多个聚合向量,确定本聚合层针对所述多个用户节点的多个聚合向量;
在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
6.根据权利要求4或5所述的方法,其中,基于训练出的第一图神经网络确定该未标注节点的第二分类预测向量,包括:
多次执行确定所述第二分类预测向量的操作,对应得到多个第二分类预测向量;
其中,根据所述第二分类预测向量,确定第二信息熵,包括:
将所述多个第二分类预测向量所对应多个信息熵的均值,确定为所述第二信息熵。
7.根据权利要求1所述的方法,其中,根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数,包括:
根据所述各个标注节点对应的分类预测向量和真实分类标签,确定第一损失项;
针对所述各个未标注节点,根据其对应的分类预测向量和伪分类标签,确定第二损失项,并利用其对应的信息增益对所述第二损失项进行加权处理;
根据所述第一损失项和加权处理后的第二损失项,更新所述模型参数。
8.根据权利要求7所述的方法,其中,利用其对应的信息增益对所述第二损失项进行加权处理,包括:
利用所述第一数量的未标注节点所对应第一数量的信息增益,对所述各个未标注节点的信息增益进行归一化处理,得到对应的加权系数;
利用所述加权系数进行所述加权处理。
9.一种图神经网络的训练方法,涉及基于预先构建的关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:
利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;
基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签;
针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;
根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
10.一种图神经网络的训练装置,所述装置通过以下单元,根据用户关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:
分类预测单元,配置为利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;
伪标签分配单元,配置为基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;
信息增益确定单元,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;
参数更新单元,配置为根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
11.一种图神经网络的训练装置,所述装置通过以下单元,根据预先构建的关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:
分类预测单元,配置为利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;
伪标签分配单元,配置为基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签;
信息增益确定单元,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;
参数更新单元,配置为根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
12.一种计算机可读存储介质,其上存储有计算机程序,其中,当所述计算机程序在计算机中执行时,令计算机执行权利要求1-9中任一项所述的方法。
13.一种计算设备,包括存储器和处理器,其中,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-9中任一项所述的方法。
CN202210440602.3A 2022-04-25 2022-04-25 图神经网络的训练方法及装置 Active CN114707644B (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202210440602.3A CN114707644B (zh) 2022-04-25 2022-04-25 图神经网络的训练方法及装置
US18/306,144 US20230342606A1 (en) 2022-04-25 2023-04-24 Training method and apparatus for graph neural network

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210440602.3A CN114707644B (zh) 2022-04-25 2022-04-25 图神经网络的训练方法及装置

Publications (2)

Publication Number Publication Date
CN114707644A true CN114707644A (zh) 2022-07-05
CN114707644B CN114707644B (zh) 2024-09-06

Family

ID=82173699

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210440602.3A Active CN114707644B (zh) 2022-04-25 2022-04-25 图神经网络的训练方法及装置

Country Status (2)

Country Link
US (1) US20230342606A1 (zh)
CN (1) CN114707644B (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115545172A (zh) * 2022-11-29 2022-12-30 支付宝(杭州)信息技术有限公司 兼顾隐私保护和公平性的图神经网络的训练方法及装置
CN116896510B (zh) * 2023-02-09 2024-04-26 兰州大学 一种面向二分网络的基于奇数长度路径的链路预测方法
WO2024120166A1 (zh) * 2022-12-08 2024-06-13 马上消费金融股份有限公司 数据处理方法、类别识别方法及计算机设备

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112766500A (zh) * 2021-02-07 2021-05-07 支付宝(杭州)信息技术有限公司 图神经网络的训练方法及装置
CN112906873A (zh) * 2021-03-26 2021-06-04 北京邮电大学 一种图神经网络训练方法、装置、电子设备及存储介质
WO2021204763A1 (en) * 2020-04-07 2021-10-14 Koninklijke Philips N.V. Training a convolutional neural network
US11227190B1 (en) * 2021-06-29 2022-01-18 Alipay (Hangzhou) Information Technology Co., Ltd. Graph neural network training methods and systems
US20220083840A1 (en) * 2020-09-11 2022-03-17 Google Llc Self-training technique for generating neural network models

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021204763A1 (en) * 2020-04-07 2021-10-14 Koninklijke Philips N.V. Training a convolutional neural network
US20220083840A1 (en) * 2020-09-11 2022-03-17 Google Llc Self-training technique for generating neural network models
CN112766500A (zh) * 2021-02-07 2021-05-07 支付宝(杭州)信息技术有限公司 图神经网络的训练方法及装置
CN112906873A (zh) * 2021-03-26 2021-06-04 北京邮电大学 一种图神经网络训练方法、装置、电子设备及存储介质
US11227190B1 (en) * 2021-06-29 2022-01-18 Alipay (Hangzhou) Information Technology Co., Ltd. Graph neural network training methods and systems

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
张玮桐: "基于图表示和标签传播的复杂网络社区检测及其应用", 中国博士学位论文全文数据库 基础科学辑, no. 04, 15 April 2022 (2022-04-15) *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115545172A (zh) * 2022-11-29 2022-12-30 支付宝(杭州)信息技术有限公司 兼顾隐私保护和公平性的图神经网络的训练方法及装置
CN115545172B (zh) * 2022-11-29 2023-02-07 支付宝(杭州)信息技术有限公司 兼顾隐私保护和公平性的图神经网络的训练方法及装置
WO2024120166A1 (zh) * 2022-12-08 2024-06-13 马上消费金融股份有限公司 数据处理方法、类别识别方法及计算机设备
CN116896510B (zh) * 2023-02-09 2024-04-26 兰州大学 一种面向二分网络的基于奇数长度路径的链路预测方法

Also Published As

Publication number Publication date
CN114707644B (zh) 2024-09-06
US20230342606A1 (en) 2023-10-26

Similar Documents

Publication Publication Date Title
US12072998B2 (en) Differentially private processing and database storage
Solus et al. Consistency guarantees for greedy permutation-based causal inference algorithms
JP7169369B2 (ja) 機械学習アルゴリズムのためのデータを生成する方法、システム
CN114707644B (zh) 图神经网络的训练方法及装置
CN112541575B (zh) 图神经网络的训练方法及装置
US9269055B2 (en) Data classifier using proximity graphs, edge weights, and propagation labels
US20230049817A1 (en) Performance-adaptive sampling strategy towards fast and accurate graph neural networks
CN110019790A (zh) 文本识别、文本监控、数据对象识别、数据处理方法
CN112988840A (zh) 一种时间序列预测方法、装置、设备和存储介质
US20220327394A1 (en) Learning support apparatus, learning support methods, and computer-readable recording medium
CN104077765A (zh) 图像分割装置、图像分割方法和程序
CN110889493A (zh) 针对关系网络添加扰动的方法及装置
CN114037518A (zh) 风险预测模型的构建方法、装置、电子设备和存储介质
CN112528109B (zh) 一种数据分类方法、装置、设备及存储介质
Almomani et al. Selecting a good stochastic system for the large number of alternatives
CN110717037A (zh) 对用户分类的方法和装置
Hou et al. Three-step risk inference in insurance ratemaking
CN110852080B (zh) 订单地址的识别方法、系统、设备和存储介质
CN113177596B (zh) 一种区块链地址分类方法和装置
CN113590721B (zh) 一种区块链地址分类方法和装置
CN114547448B (zh) 数据处理、模型训练方法、装置、设备、存储介质及程序
CN116089722B (zh) 基于图产出标签的实现方法、装置、计算设备和存储介质
CN116824305B (zh) 应用于云计算的生态环境监测数据处理方法及系统
CN110705642B (zh) 分类模型、方法、装置、电子设备及存储介质
Zhang et al. S2NMF: Information Self‐Enhancement Self‐Supervised Nonnegative Matrix Factorization for Recommendation

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant