CN112766500A - 图神经网络的训练方法及装置 - Google Patents

图神经网络的训练方法及装置 Download PDF

Info

Publication number
CN112766500A
CN112766500A CN202110177564.2A CN202110177564A CN112766500A CN 112766500 A CN112766500 A CN 112766500A CN 202110177564 A CN202110177564 A CN 202110177564A CN 112766500 A CN112766500 A CN 112766500A
Authority
CN
China
Prior art keywords
graph
matrix
node
neural network
network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110177564.2A
Other languages
English (en)
Other versions
CN112766500B (zh
Inventor
李群伟
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Alipay Hangzhou Information Technology Co Ltd
Original Assignee
Alipay Hangzhou Information Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Alipay Hangzhou Information Technology Co Ltd filed Critical Alipay Hangzhou Information Technology Co Ltd
Priority to CN202110177564.2A priority Critical patent/CN112766500B/zh
Publication of CN112766500A publication Critical patent/CN112766500A/zh
Application granted granted Critical
Publication of CN112766500B publication Critical patent/CN112766500B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本说明书实施例提供一种图神经网络的训练方法。该方法包括:先获取关系网络图,其中包括对应多个业务对象的多个对象节点;接着,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,该多个对象节点对应的多个融合特征形成融合特征矩阵;利用所述图神经网络对所述关系网络图进行图嵌入处理,得到该多个对象节点对应的多个嵌入向量,该图神经网络中包括激活函数,并基于该多个嵌入向量,确定多个预测结果;并且,确定该融合特征矩阵经过该激活函数处理前后的乘积矩阵;基于该乘积矩阵、多个预测结果和业务标签,确定该图神经网络中参数的训练梯度,进而基于该训练梯度,更新该图神经网络中的参数。

Description

图神经网络的训练方法及装置
技术领域
本说明书实施例涉及计算机技术领域,尤其涉及一种图神经网络的训练方法及装置。
背景技术
关系网络图是对现实世界中实体之间的关系的描述,目前广泛地应用于各种计算机信息处理中。一般地,关系网络图包含节点集合和边集合,节点表示现实世界中的实体,边表示现实世界中实体之间的联系。例如,在社交网络中,人就是实体,人和人之间的关系或联系就是边。
在很多情况下,希望对关系网络图中的节点、边等的拓扑特性进行分析,从中提取出有效信息,实现这类过程的计算方法称为图计算。典型地,希望将关系网络图中的每个节点(实体)用相同维度的向量来表示,也就是生成针对每个节点的节点向量。如此,生成的节点向量可以应用于计算节点和节点之间的相似度,发现图中的社团结构,预测未来可能形成的边联系,以及对图进行可视化等。
节点向量的生成方法已成为图计算的基础算法。根据一种方案,可以利用图神经网络(GraphNeural Networks,简称GNN),生成关系网络图中节点的节点向量。然而,在图神经网络的训练阶段,因通常需要训练数百万个参数,以及目标函数具有高度非凸的特性,常规训练需要消耗大量的计算资源,并且,难以保证训练的收敛性和训练出的图神经网络的嵌入效果。
因此,需要一种改进的方案,可以减小图神经网络的训练消耗,加速收敛,提升图神经网络的嵌入表达性能。
发明内容
采用本说明书实施例描述的图神经网络的训练方法及装置,在训练过程中,避开复杂而耗时的对于激活函数的梯度求解,而使用一种代替该梯度的数值,该数据可以由训练数据直接算得,从而有效减少计算次数和计算耗时,使得图神经网络中的参数呈现可控的线性收敛趋势,快速得到训练好的、嵌入表征性能优异的图神经网络。
根据第一方面,提供一种图神经网络的训练方法,包括:获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签;针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;基于所述多个嵌入向量,确定多个预测结果;确定所述融合特征矩阵经过所述激活函数处理前后的乘积矩阵;基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;基于所述训练梯度,更新所述图神经网络中的参数。
在一个实施例中,所述多个业务对象涉及以下中的至少一种:用户、商品、商户、事件。
在一个实施例中,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,包括:针对所述各个对象节点,对其节点特征与其邻居节点的节点特征进行平均处理,得到所述融合特征。
在一个实施例中,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,包括:获取所述多个对象节点对应的节点特征矩阵、度矩阵和邻接矩阵;基于所述度矩阵、邻接矩阵和节点特征矩阵进行相乘处理,得到所述融合特征矩阵。
在一个实施例中,所述训练梯度所对应的损失函数基于多个预测结果与多个业务标签之间的差值向量而设定;其中,基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度,包括:基于所述乘积矩阵、融合特征矩阵以及差值向量进行相乘处理,得到所述训练梯度。
在一个实施例中,基于所述多个嵌入向量,确定所述多个对象节点对应的多个预测结果,包括:将所述多个嵌入向量分别输入预测网络,得到所述多个预测结果;在得到所述多个预测结果之后,所述方法还包括:基于所述多个预测结果和业务标签,更新所述预测网络中的参数。
在一个具体的实施例中,所述训练梯度所对应的损失函数基于所述多个预测结果与多个业务标签之间的差值向量而设定,所述预测网络利用参数向量对输入的嵌入向量进行线性变换处理;其中,基于所述乘积矩阵、多个预测结果和多个业务标签,确定所述图神经网络中参数的训练梯度,包括:基于所述乘积矩阵、融合特征矩阵、差值向量和所述参数向量进行相乘处理,得到所述训练梯度。
在一个更具体的实施例中,基于所述多个预测结果和业务标签,更新所述预测网络中的参数,包括:基于所述融合特征矩阵、所述差值向量,以及所述图神经网络的当前参数,确定所述参数向量对应的梯度向量,并根据所述梯度向量更新所述预测网络。
在一个实施例中,基于所述训练梯度,更新所述图神经网络中的参数,包括:确定预设学习率与训练梯度之间的乘积;将所述图神经网络中的参数更新为其与所述乘积之差。
根据第二方面,提供一种图神经网络的训练方法,包括:获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签;针对所述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;利用所述第一关系网络图对应的第一权重向量,对所述多个嵌入向量进行加权处理,得到所述第一关系网络图对应的图表征向量;基于多个关系网络图对应的多个图表征向量,确定多个预测结果;针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过所述激活函数处理后得到的矩阵进行乘积处理,得到所述多个关系网络图对应的多个乘积矩阵;基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;基于所述训练梯度,更新所述图神经网络中的参数。
在一个实施例中,所述多个业务对象涉及以下中的至少一种:用户、商品、商户、事件、化学元素。
在一个实施例中,所述第一权重向量中各个向量元素的值相等,或者,所述第一权重向量中包含单个非零元素。
在一个实施例中,所述训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定;其中,基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度,包括:针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;基于所述平均矩阵以及求和矩阵进行相乘处理,得到所述训练梯度。
在一个实施例中,基于多个关系网络图对应的多个图表征向量,确定多个预测结果,包括:将所述多个图表征向量分别输入预测网络,得到所述多个预测结果;在得到所述多个预测结果之后,所述方法还包括:基于所述多个预测结果和业务标签,更新所述预测网络中的参数。
在一个具体的实施例中,所述训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定,所述预测网络利用参数向量对输入的图表征向量进行线性变换处理;其中,基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度,包括:针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;基于所述平均矩阵、求和矩阵以及参数向量进行相乘处理,得到所述训练梯度。
在一个更具体的实施例中,基于所述多个预测结果和业务标签,更新所述预测网络中的参数,包括:基于所述多个融合特征矩阵、多个权重向量、多个预测结果与多个业务标签之间的多个差值,以及所述图神经网络的当前参数,确定所述参数向量对应的梯度向量,并根据所述梯度向量更新所述预测网络。
根据第三方面,提供一种图神经网络的训练装置,包括:图谱获取单元,配置为获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签;特征融合单元,配置为针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;图嵌入单元,配置为利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;预测单元,配置为基于所述多个嵌入向量,确定多个预测结果;矩阵确定单元,配置为确定所述融合特征矩阵经过所述激活函数处理前后的乘积矩阵;梯度确定单元,配置为基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;参数更新单元,配置为基于所述训练梯度,更新所述图神经网络中的参数。
根据第四方面,提供一种图神经网络的训练装置,包括:图谱获取单元,配置为获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签;特征融合单元,配置为针对所述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;图嵌入单元,配置为利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;图表征单元,配置为利用所述第一关系网络图对应的第一权重向量,对所述多个嵌入向量进行加权处理,得到所述第一关系网络图对应的图表征向量;预测单元,配置为基于多个关系网络图对应的多个图表征向量,确定多个预测结果;矩阵确定单元,配置为针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过所述激活函数处理后得到的矩阵进行乘积处理,得到所述多个关系网络图对应的多个乘积矩阵;梯度确定单元,配置为基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;参数更新单元,配置为基于所述训练梯度,更新所述图神经网络中的参数。
根据第五方面,提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行第一方面或第二方面所描述的方法。
根据第六方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现第一方面或第二方面所描述的方法。
综上,采用本说明书实施例披露的图神经网络的训练方法及装置,在训练过程中,避开复杂而耗时的对于激活函数的梯度求解,而使用一种代替该梯度的数值,该数据可以由训练数据直接算得,从而有效减少计算次数和计算耗时,使得图神经网络中的参数呈现可控的线性收敛趋势,快速得到训练好的、嵌入表征性能优异的图神经网络。
附图说明
为了更清楚地说明本说明书披露的多个实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本说明书披露的多个实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1示出根据一个实施例的图神经网络的训练方法流程图;
图2示出根据另一个实施例的图神经网络的训练方法流程图;
图3示出根据一个实施例的图神经网络的训练装置结构图;
图4示出根据另一个实施例的图神经网络的训练装置结构图。
具体实施方式
下面结合附图,对本说明书披露的多个实施例进行描述。
如前所述,目前,图神经网络的训练遇到瓶颈。具体而言,图神经网络是可以在社交网络或其它基于图形拓扑数据上运行的深度学习架构,它是一种基于图拓扑结构的广义神经网络。图神经网络一般是将底层关系网络图作为计算图,并通过在整张图上传递、转换和聚合节点特征信息,从而学习神经网络基元以生成单节点嵌入向量。生成的节点嵌入向量可作为预测层的输入,并用于节点分类或预测节点之间的连接,完整的模型可以通过端到端的方式训练。
因深度神经网络具有高度非凸性等原因,给图神经网络的训练收敛性分析及训练算法的改进,带来了巨大的挑战。比如说,在具有无界分布和平方损失的不可知论环境中,以单元范数隐藏权重向量,学习单个线性整流函数就非常困难。
进一步,发明人发现,在传统的随机梯度下降等训练方式中,利用反向传播梯度去更新网络参数的过程是一个公认的黑盒子,其训练结果不可控且不可预知,造成这一过程是个黑盒的主要原因在于,需要对激活函数进行梯度求解,而对激活函数进行梯度求解复杂而耗时。
基于此,发明人提出一种白盒式的训练算法,该算法避开对激活函数的梯度求解,而使用了一种代替该梯度的数值,该数值可以由训练数据直接算得。而且,可以通过理论证明此种白盒式算法的有效性和训练结果的可控性,开创了图神经网络领域的先河。
为便于理解,下面对图神经网络的训练实施流程进行示例性描述。首先需说明,图神经网络依托的图结构包括两种,一种是有多张关系网络图,图中的节点包含特征信息,每张图的节点信息需要做特征聚合来表征相应的图的特征,训练数据中的标签信息对应于每张图;另一种是单张图,图中存在多个节点,图中的节点包含特征数据,且包含标签信息。
针对某张关系网络图将其中每个节点自身和其邻居节点的特征做融合,将融合更新后的特征作为该每个节点的新特征;然后,可以将每个节点的新特征输入一个共同的深度神经网络中,其输出作为相应节点的嵌入表达,或者,将每个节点的原始特征输入该深度神经网络中,其输出作为相应节点的嵌入表达。
之后,对于上述标签信息对应每张图的情况,基于每张图中所有节点的嵌入表达做特征聚合,并用聚合得到的聚合特征作为该张图的特征表述。对于上述标签信息在图中节点上的情况,则不做特征聚合,此时,节点的嵌入表达即为相应节点的特征表述。
接下来,将图或节点的特征表述输入到预测处理层中,其输出作为图或节点的标签预测。在得到预测结果后,基于预测结果,相应图或节点的标签,以及基于相应训练数据计算出的用于代替激活函数梯度的数值,确定图神经网络,或图神经网络和预测网络的训练梯度,进行网络参数的更新。
下面描述上述白盒算法的实现步骤,为清楚描述,对标签信息对应节点和对应整张图的情况进行分开描述。图1示出根据一个实施例的图神经网络的训练方法流程图,所述方法的执行主体可以为任何计算、处理能力的装置、平台、服务器或设备集群。如图1所示,所述方法包括以下步骤:
步骤S101,获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签;步骤S103,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,上述多个对象节点对应的多个融合特征形成融合特征矩阵;步骤S105,利用上述图神经网络对上述关系网络图进行图嵌入处理,得到上述多个对象节点对应的多个嵌入向量;上述图神经网络中包括激活函数;步骤S107,基于上述多个嵌入向量,确定多个预测结果;步骤S109,确定上述融合特征矩阵经过上述激活函数处理前后的乘积矩阵;步骤S111,基于上述乘积矩阵、多个预测结果和业务标签,确定上述图神经网络中参数的训练梯度;步骤S113,基于上述训练梯度,更新上述图神经网络中的参数。
对以上步骤的展开介绍如下:
首先,在步骤S101,获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签。
上述多个业务对象可以涉及以下中的至少一种:用户、商户、商品、终端设备、事件。
在一个实施例中,上述多个业务对象中包括多个用户,上述关联关系可以包括以下中的至少一种:社交关系、设备关系、交易关系和内容交互关系。在一个具体的实施例中,若两个用户在社交平台中互相关注或者互为好友,则认为对应的两个用户节点之间存在社交关系。在一个具体的实施例中,若两个用户使用过同一个终端设备,则认为对应的节点之间存在设备关系。在一个具体的实施例中,若一个用户曾向另一个用户转账或发起收款,则认为二者间存在交易关系。在一个具体的实施例中,若两个用户互相发送过内容,则认为二者间存在内容交互关系。在一个例子中,其中内容可以是文本、链接、图片(如动态表情)、或视频等。在另一个实施例中,上述多个业务对象中可以包括多个商户,上述关联关系可以包括合作关系或上下游供应关系。
在又一个实施例中,上述关系网络图可以是用户-商品二部图,相应地,上述多个业务对象中可以包括多个用户和多个商品,并且,连接边只存在于用户节点和商品节点之间,用户节点之间和商品节点之间不存在连接边,进一步,上述关联关系可以包括点击关系、或购买关系或评价关系。在一个具体的实施例中,若某个用户曾点击过针对某件商品的广告信息,则认为二者之间存在点击关系。在另一个具体的实施例中,若某个用户曾今购买过某件商品,则认为二者之间存在购买关系。
关于对象节点携带的业务标签。在一个实施例中,在节点分类场景下,多个对象节点中的各个对象节点各自携带类别标签。在一个具体的实施例中,对象节点属于用户节点,相应的类别标签可以是用户风险标签,如高风险、中风险和低风险等。在另一个具体的实施例中,对象节点属于商品节点,相应的类别标签可以是商品热度标签,如爆款商品、热销商品和冷门商品等。
在另一个实施例中,在链路预测场景下,两个对象节点可以共同对应一个业务标签,在一个具体的实施例中,用户节点和商品节点对应一个行为标签,该行为标签指示该用户节点是否对商品节点做出预设行为,如购买、评论或点击。在另一个具体的实施例中,两个商户节点对应一个供应关系标签,该供应关系标签指示两个商户之间是否存在供应关系。
以上,对获取的关系网络图,以及其中多个对象节点对应的业务对象、业务标签,还有对象节点之间的关联关系进行介绍。
在获取到关系网络图后,在步骤S103,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵。
在一个实施例中,本步骤可以包括:针对所述各个对象节点,对其自身和其邻居节点的节点特征进行平均处理,得到所述融合特征。在一个示例中,将上述多个对象节点所对应的节点特征矩阵记作H,将第i个节点的节点特征记作Hi,其邻居节点集记作Ni,相应,可以将第i个节点对应的融合特征表示为
Figure BDA0002941141440000111
其中|Ni|表示集合Ni中的元素个数。由此,经过多次平均处理,可以得到多个对象节点对应的多个融合特征,该多个融合特征堆叠后可得到融合特征矩阵。
在另一个实施例中,本步骤可以包括:获取多个对象节点对应的节点特征矩阵、度矩阵和邻接矩阵;基于所述度矩阵、邻接矩阵和节点特征矩阵进行相乘处理,得到所述融合特征矩阵。需说明,度矩阵是对角矩阵,对角上的元素为各个对象节点的度,任一节点的度表示和该节点相连接的连接边的数量。邻接矩阵表示节点间的连接关系,是n阶方阵(n为对象节点的数量),例如,在无向图中,两个节点间有连接,则对应矩阵元素取值为1,否则为0,而在有向图中,还需考虑连接边的方向性。在一个示例中,将度矩阵、邻接矩阵和节点特征矩阵依次记作D、A和H,相应可以将相乘得到的融合特征矩阵记作D-1AH,式中上标-1表示矩阵的逆,并且,融合特征矩阵D-1AH中包括各个对象节点的融合特征,与利用该节点的节点特征与其邻居节点的节点特征进行平均处理而得到的结果相同。
在还一个实施例中,本步骤可以包括:针对所述各个对象节点,基于其自身与其邻居节点的节点特征,确定其为其自身及其各个邻居节点分配的注意力权重,再基于注意力权重对相应的节点特征进行加权求和,得到对应节点的融合特征。在一个具体的实施例中,其中注意权重的计算可以包括:计算该对象节点与其自身以及其各个邻居节点之间的节点特征相似度,再将相似度进行归一化处理(如采用softmax函数,或者求占比的计算方式),得到对应的注意力权重。在另一个具体的实施例中,可以引入注意力打分网络,具体,将对象节点的节点特征,分别与其自身以及其各个邻居节点的节点特征进行拼接,并将多个拼接特征分别输入注意力打分网络,得到多个注意力分数,对该多个注意力分数进行归一化处理,对应得到多个注意力权重。
由上,可以得到多个对象节点对应的多个融合特征,以及多个融合特征组成的融合特征矩阵。基于此,在步骤S105,利用上述图神经网络对上述关系网络图进行图嵌入处理,得到上述多个对象节点对应的多个嵌入向量;上述图神经网络中包括激活函数。
在一个实施例中,上述图神经网络可以实现为深度神经网络DNN(Deep NeuralNetwork),图卷积神经网络GCN(Graph Convolutional Network),图注意力网络GAT(GraphAttention Network)。相应,该图神经网络可以基于上述多个融合特征,对关系网络图进行图嵌入处理,得到各个对象节点对应的各个节点嵌入向量。
对于上述图神经网络中包含的激活函数,在一个实施例中,图神经网络可以包括一个或多个网络层,其中任意数量的网络层中可以包括激活函数,其基于输入对应网络层的特征进行非线性变换处理。在一个具体的实施例中,对于其中任意两个网络层中包含的激活函数,可以相同,也可以不同。另一方面,在一个实施例中,激活函数包括ReLU,LeakyReLU,Sigmoid和Softplus等。在一个实施例中,激活函数可以实现为非平凡的递增函数,且1阶利普西茨连续(1-Lipschiz continuous)。
由上,利用图神经网络处理多个融合特征,可以得到多个对象节点对应的多个嵌入向量。在步骤S107,基于该多个嵌入向量,确定多个预测结果。需理解,该多个预测结果,与由关系网络图中对象节点携带的多个业务标签相对应。
在一个实施例中,在节点分类场景下,上述多个对象节点对应多个业务标签,相应,可以将该多个对象节点对应的多个嵌入向量输入预测网络中,得到多个预测结果,该多个预测结果与该多个业务标签相对应。
在另一个实施例中,在链路预测场景下,多组对象节点对应多个业务标签,相应,在一个具体的实施例中,可以将各组对象节点所对应的两个或以上嵌入向量进行拼接后,输入预测网络中,得到对应的预测结果,多组对象节点对应多个预测结果;在另一个具体的实施例中,可以计算各组对象节点所对应的两个嵌入向量的相似度,作为对应的预测结果,多组对象节点对应多个预测结果。
以上,可以基于多个嵌入向量确定多个预测结果。另一方面,在步骤S109,确定上述融合特征矩阵经过上述激活函数处理前后的乘积矩阵。在一个示例中,将融合矩阵记作c,激活函数记作σ,相应,可以利用下式(1)计算乘积矩阵ΞN,其中上标T表示矩阵的转置。
ΞN=cTσ(c) (1)
需说明,步骤S109在步骤S103之后执行即可,步骤S109相对步骤S105、S107的执行顺序不作限定。
以上,在确定乘积矩阵和多个预测结果后,在步骤S111,基于该乘积矩阵、多个预测结果和上述业务标签,确定图神经网络中参数的训练梯度。
需说明,对于图神经网络中包含激活函数的网络层,其对融合特征矩阵的处理可以表示为σ(cW),其中σ表示激活函数,c表示融合特征矩阵,W表示网络层中用于对c进行线性变换处理的参数矩阵。在传统的随机梯度下降法中,在计算W的梯度时,需求解激活函数的梯度项σ′(c),然而,就大多数激活函数而言,求解σ′(c)的计算量巨大,为了得到W和其他与图神经网络联合训练的神经网络的网络参数的真实梯度,需要对训练数据集进行较高数量级的遍历,而利用上述乘积矩阵ΞN,作为σ′(c)的近似替代项,可以直接根据训练数据计算得到,只需遍历一次训练数据集,并且,可以实现图神经网络的参数随着训练迭代次数的增加而线性收敛,这样的训练收敛过程和训练效果是可控的,并且经过理论论证,只需消耗相对常规方式小得多得计算量,即可训练出嵌入表征效果优异的图神经网络。
在一种实施方式中,可以基于预定的损失函数形式,对多个预测结果以及相对应的多个业务标签进行计算处理,得到训练损失,再基于训练损失通过求偏导运算,确定图神经网络参数的梯度计算表达式,进而将该梯度计算表达式中针对激活函数的梯度求解项替换成上述乘积矩阵,由此得到新的梯度计算表达式,用于确定实际训练过程中的训练梯度。在一个实施例中,其中损失函数形式可以是交叉熵损失函数、铰链损失函数、二阶范数的平方等。
在一个实施例中,训练梯度所对应的损失函数基于多个预测结果与多个业务标签之间的差值向量而设定,例如,损失函数形式可以是差值向量的二阶范数的平方,或者,可以是差值向量的平方或绝对值的立方等。相应地,图神经网络中参数训练梯度的确定可以包括:基于上述乘积矩阵、融合特征矩阵以及差值向量进行相乘处理,得到训练梯度。
进一步,在一个具体的实施例中,上述多个预测结果基于上述预测网络而确定,预测网络利用参数向量对输入的嵌入向量进行线性变换处理,相应地,训练梯度的确定可以包括:基于该参数向量、上述乘积矩阵、融合特征矩阵和差值向量进行相乘处理,得到上述训练梯度。在一个示例中,训练梯度的计算式如下:
Figure BDA0002941141440000141
在公式(2)中,
Figure BDA0002941141440000142
表示第t轮次的迭代训练中图神经网络参数W的训练梯度;n表示上述多个对象节点的节点总数;ΞN表示上述乘积矩阵,上标-1表示矩阵的逆;c表示融合特征矩阵;y表示多个业务标签组成的向量;
Figure BDA0002941141440000143
表示多个预测结果组成的向量;vt表示第t轮迭代预测网络中更新前的参数向量,上标T表示矩阵的转置。需要说明,其中
Figure BDA0002941141440000144
的确定过程可表示为下式(3),σ表示激活函数。
Figure BDA0002941141440000145
由上,可以得到图神经网络参数的训练梯度。另一方面,可以对预测网络和图神经网络进行联合训练,相应地,在一个实施例中,在得到多个预测结果后,上述方法还可以包括:基于多个预测结果和业务标签,更新所述预测网络中的参数。在一个具体的实施例中,所述预测网络利用参数向量对输入的嵌入向量进行线性变换处理,上述训练梯度所对应的损失函数基于所述多个预测结果与多个业务标签之间的差值向量而设定,相应,可以基于上述融合特征矩阵、差值向量,以及图神经网络的当前参数,确定预测网络中的参数向量对应的梯度向量,并根据该梯度向量更新预测网络。
在一个示例中,与上述公式(2)和(3)相关联的,可以采用下式(4),计算参数向量对应的梯度向量。
Figure BDA0002941141440000151
在公式(4)中,
Figure BDA0002941141440000152
表示第t轮次的迭代训练中预测网络中参数向量v的梯度向量;Wt表示t轮迭代图神经网络中更新前的参数;对其他符号的说明可以参见公式(2)和(3)中的相关说明。
如此,可以实现对预测网络中参数的更新。
以上,可以确定出图神经网络中参数的训练梯度,接着在步骤S113,基于该训练梯度,更新上述图神经网络中的参数。在一个实施例中,本步骤可以包括:确定预设学习率与训练梯度之间的乘积;将所述图神经网络中的参数更新为其与所述乘积之差。在一个具体的实施例中,其中预设学习率为超参,可以设定为0.1或0.2等。在一个示例中,可以用下式实现图神经网络参数的更新,其中α表示学习率。
Figure BDA0002941141440000153
在另一示例中,还可以用下式(6)和(7)实现图神经网络参数的更新。
Figure BDA0002941141440000154
Wt+1=Ut+1/‖Ut+12 (7)
由上,而可以实现对图神经网络中参数的更新。
综上,采用本说明书实施例披露的图神经网络的训练方法,在训练过程中,避开复杂而耗时的对于激活函数的梯度求解,而使用一种代替该梯度的数值,该数据可以由训练数据直接算得,从而有效减少计算次数和计算耗时,使得图神经网络中的参数呈现可控的线性收敛趋势,快速得到训练好的、嵌入表征性能优异的图神经网络。
以上,主要对节点携带标签信息情况下图神经网络的训练方法进行介绍,下面,着重对标签信息对应整张图的情况下的图神经网路训练进行描述。
图2示出根据另一个实施例的图神经网络的训练方法流程图,所述方法的执行主体可以为任何计算、处理能力的装置、平台、服务器或设备集群。
如图2所示,所述方法包括以下步骤:
步骤S202,获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签;步骤S204,针对上述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,上述多个对象节点对应的多个融合特征形成融合特征矩阵;步骤S206,利用上述图神经网络对上述关系网络图进行图嵌入处理,得到上述多个对象节点对应的多个嵌入向量;上述图神经网络中包括激活函数;步骤S208,利用上述第一关系网络图对应的第一权重向量,对上述多个嵌入向量进行加权处理,得到上述第一关系网络图对应的图表征向量;步骤S210,基于多个关系网络图对应的多个图表征向量,确定多个预测结果;步骤S212,针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过上述激活函数处理后得到的矩阵进行乘积处理,得到上述多个关系网络图对应的多个乘积矩阵;步骤S214,基于上述多个乘积矩阵的平均矩阵,上述多个预测结果和业务标签,确定上述图神经网络中参数的训练梯度;步骤S216,基于上述训练梯度,更新上述图神经网络中的参数。
对以上步骤的展开介绍如下:
首先,在步骤S202,获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签。
上述多个业务对象可以涉及以下中的至少一种:用户、商户、商品、终端设备、事件、化学元素等。在一个实施例中,上述多个业务对象中包括多个化学元素,上述关联关系可以包括化学键,若两个化学元素之间存在化学键,则在对应节点之间建立连接边。需说明,对关系网络图的描述还可以参见前述实施例中的相关描述。
对于关系网络图携带的业务标签。在一个实施例中,在图分类场景下,多个关系网络图中各个关系网络图各自携带类别标签。在一个具体的实施例中,关系网络图对应化学物质的分子式,相应地,图类别标签可以为烃类、烷类等。在另一个具体的实施例中,关系网络图对应机械装置,其中各个对象节点为机械组件,连接边指示对应组件的连接关系,相应地,图类别标签可以包括绞肉机、榨汁机等。在另一个实施例中,在关系网络图间的关系预测场景下,一组关系网络图可以共同对应一个业务标签。在一个具体的实施例中,可以是两个关系网络图对应一个业务标签,该业务标签指示二者是否为同类事物。
以上,对获取的多个关系网络图进行介绍。
在获取到多个关系网络图后,在步骤S204,针对上述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,上述多个对象节点对应的多个融合特征形成融合特征矩阵。需说明,对步骤S204的描述可以参见前述对步骤S103的描述,区别在于,通过针对各个关系网络图执行步骤S204,可以得到多个关系网络图对应的多个融合特征矩阵。
进一步,在步骤S206,利用上述图神经网络对上述关系网络图进行图嵌入处理,得到上述多个对象节点对应的多个嵌入向量;上述图神经网络中包括激活函数。需说明,对步骤S206的描述可以参见前述对步骤S105的描述,区别在于,通过针对各个关系网络图执行步骤S206,可以得到多个关系网络图中各个关系网络图对应的多个节点嵌入向量。
在得到任一的第一关系网络所对应的多个嵌入向量后,在步骤S208,利用上述第一关系网络图对应的第一权重向量,对该多个嵌入向量进行加权处理,得到第一关系网络图对应的图表征向量。需说明,不同关系网络图所对应的权重向量可能相同,也可能不同。
在一个实施例中,第一权重向量中各个向量元素的值相等,且元素和值为预定数值(如1)。在另一个实施例中,第一权重向量中包含单个非零元素,也就是说,只有一个元素的值为非零(如1),其余元素的值均为0。对于非零元素的元素位置的选取,在一个例子中,可以是随机选取的,在另一个例子中,可以将邻居节点数量最多的节点所对应的元素设为预定的非零数值,在还一个例子中,可以将多个嵌入向量中方差最大的嵌入向量所对应的权重元素设定为非零元素。在还一个实施例中,第一权重向量中的向量元素为学习参数。
由上,通过对各个关系网络图对应的节点嵌入向量进行加权聚合,可以得到对应的图表征向量。在步骤S210,基于多个关系网络图对应的多个图表征向量,确定多个预测结果。需理解,该多个预测结果,与由关系网络图携带的多个业务标签相对应。
在一个实施例中,在图分类场景下,上述多个关系网络图对应多个业务标签,相应,可以将该多个关系网络图对应的多个图表征向量输入预测网络中,得到多个预测结果,该多个预测结果与该多个业务标签相对应。
在另一个实施例中,在图间关系的预测场景下,多组关系网络图对应多个业务标签,相应,在一个具体的实施例中,可以将各组关系网络图所对应的两个或以上图表征向量进行拼接或加和后,输入预测网络中,得到对应的预测结构,多组关系网络图对应多个预测结果;在另一个具体的实施例中,可以计算各组关系网络图所对应的两个图表征向量的相似度,作为对应的预测结果,多组关系网络图对应多个预测结果。
以上,可以基于多个图表征向量确定多个预测结果。另一方面,在步骤S212,针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过上述激活函数处理后得到的矩阵进行乘积处理,得到上述多个关系网络图对应的多个乘积矩阵。在一个示例中,将第j个关系网络图所对应的融合矩阵记作cj,激活函数记作σ,权重向量记作aj,多个关系网络图的数量记作n,相应,可以利用下式(1)计算多个乘积矩阵的平均矩阵。
Figure BDA0002941141440000191
需说明,步骤S212在步骤S204之后执行即可,步骤S212相对步骤S206、S208及S210的执行顺序不作限定。
以上,在确定多个乘积矩阵的平均矩阵和多个预测结果后,在步骤S214,基于该平均矩阵,上述多个预测结果和业务标签,确定上述图神经网络中参数的训练梯度。
需说明,对于图神经网络中包含激活函数的网络层,其对融合特征矩阵的处理可以表示为σ(cW),其中σ表示激活函数,c表示融合特征矩阵,W表示网络层中用于对c进行线性变换处理的参数矩阵。在传统的随机梯度下降法中,在计算W的梯度时,需求解激活函数的梯度项σ′(c),然而,就大多数激活函数而言,求解σ′(c)的计算量巨大,为了得到W和其他与图神经网络联合训练的神经网络的网络参数的真实梯度,需要对训练数据集进行较高数量级的遍历,而利用上述乘积矩阵ΞN,作为σ′(c)的近似替代项,可以直接根据训练数据计算得到,只需遍历一次训练数据集,并且,可以实现图神经网络的参数随着训练迭代次数的增加而线性收敛,这样的训练收敛过程和训练效果是可控的,并且经过理论论证,只需消耗相对常规方式小得多得计算量,即可训练出嵌入表征效果优异的图神经网络。
在一种实施方式中,可以基于预定的损失函数形式,对多个预测结果以及相对应的多个业务标签进行计算处理,得到训练损失,再基于训练损失通过求偏导运算,确定图神经网络参数的梯度计算表达式,进而将该梯度计算表达式中针对激活函数的梯度求解项替换成上述乘积矩阵,由此得到新的梯度计算表达式,用于确定实际训练过程中的训练梯度。在一个实施例中,其中损失函数形式可以是交叉熵损失函数、铰链损失函数、二阶范数的平方等。
在一个实施例中,训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定,例如,损失函数形式可以是差值的平方或者绝对值的立方等。相应地,图神经网络中参数训练梯度的确定可以包括:针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;基于所述平均矩阵以及求和矩阵进行相乘处理,得到所述训练梯度。
进一步,在一个具体的实施例中,上述多个预测结果基于上述预测网络而确定,预测网络利用参数向量对输入的嵌入向量进行线性变换处理,相应地,训练梯度的确定可以包括:基于所述平均矩阵、求和矩阵以及参数向量进行相乘处理,得到所述训练梯度。在一个示例中,训练梯度的计算式如下:
Figure BDA0002941141440000201
在公式(9)中,
Figure BDA0002941141440000202
表示第t轮次的迭代训练中图神经网络参数W的训练梯度;n表示上述多个关系网络图的图总数;ΞG表示上述平均矩阵,上标-1表示矩阵的逆;cj表示第j个关系网络图对应的融合特征矩阵;yj表示第j个关系网络图对应的业务标签,
Figure BDA0002941141440000203
表示针对第j个关系网络图的预测结果;vt表示第t轮迭代预测网络中更新前的参数向量,上标T表示矩阵的转置。需要说明,其中
Figure BDA0002941141440000204
的确定过程可表示为下式(10),σ表示激活函数。
Figure BDA0002941141440000205
由上,可以得到图神经网络参数的训练梯度。另一方面,可以对预测网络和图神经网络进行联合训练,相应地,在一个实施例中,在得到多个预测结果后,上述方法还可以包括:基于多个预测结果和业务标签,更新所述预测网络中的参数。在一个具体的实施例中,所述预测网络利用参数向量对输入的嵌入向量进行线性变换处理,上述训练梯度所对应的损失函数基于所述预测结果与业务标签之间的差值而设定,相应,可以基于上述多个融合特征矩阵、多个权重向量、多个预测结果与多个业务标签之间的多个差值,以及所述图神经网络的当前参数,确定参数向量对应的梯度向量,并根据该梯度向量更新预测网络。
在一个示例中,与上述公式(9)和(10)相关联的,可以采用下式(11),计算参数向量对应的梯度向量。
Figure BDA0002941141440000211
在公式(11)中,
Figure BDA0002941141440000212
表示第t轮次的迭代训练中预测网络中参数向量v的梯度向量;Wt表示t轮迭代图神经网络中更新前的参数;对其他符号的说明可以参见公式(9)和(10)中的相关说明。
如此,可以实现对预测网络中参数的更新。
以上,可以确定出图神经网络中参数的训练梯度,接着在步骤S216,基于该训练梯度,更新上述图神经网络中的参数。需说明,对步骤S216的描述,可以参见对前述步骤S113的描述,不作赘述。
综上,采用本说明书实施例披露的图神经网络的训练方法,在训练过程中,避开复杂而耗时的对于激活函数的梯度求解,而使用一种代替该梯度的数值,该数据可以由训练数据直接算得,从而有效减少计算次数和计算耗时,使得图神经网络中的参数呈现可控的线性收敛趋势,快速得到训练好的、嵌入表征性能优异的图神经网络。
与上训练方法相对应的,本说明书实施例还披露训练装置。具体如下:
图3示出根据一个实施例的图神经网络的训练装置结构图。如图3所示,所述装置300包括以下单元:
图谱获取单元301,配置为获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签。特征融合单元303,配置为针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵。图嵌入单元305,配置为利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数。预测单元307,配置为基于所述多个嵌入向量,确定多个预测结果。矩阵确定单元309,配置为确定所述融合特征矩阵经过所述激活函数处理前后的乘积矩阵。梯度确定单元311,配置为基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度。参数更新单元313,配置为基于所述训练梯度,更新所述图神经网络中的参数。
在一个实施例中,所述多个业务对象涉及以下中的至少一种:用户、商品、商户、事件。
在一个实施例中,特征融合单元303具体配置为:针对各个对象节点,对其节点特征与其邻居节点的节点特征进行平均处理,得到所述融合特征。
在一个实施例中,特征融合单元303具体配置为:获取所述多个对象节点对应的节点特征矩阵、度矩阵和邻接矩阵;基于所述度矩阵、邻接矩阵和节点特征矩阵进行相乘处理,得到所述融合特征矩阵。
在一个实施例中,所述训练梯度所对应的损失函数基于所述多个预测结果与多个业务标签之间的差值向量而设定;其中梯度确定单元311具体配置为:基于所述乘积矩阵、融合特征矩阵以及差值向量进行相乘处理,得到所述训练梯度。
在一个实施例中,预测单元307具体配置为:将所述多个嵌入向量分别输入预测网络,得到所述多个预测结果;所述参数更新单元313还配置为:基于所述多个预测结果和业务标签,更新所述预测网络中的参数。
在一个具体的实施例中,所述训练梯度所对应的损失函数基于所述多个预测结果与多个业务标签之间的差值向量而设定,所述预测网络利用参数向量对输入的嵌入向量进行线性变换处理;其中梯度确定单元311具体配置为:基于所述乘积矩阵、融合特征矩阵、差值向量和所述参数向量进行相乘处理,得到所述训练梯度。
在一个更具体的实施例中,参数更新单元313进一步配置为:基于所述融合特征矩阵、所述差值向量,以及所述图神经网络的当前参数,确定所述参数向量对应的梯度向量,并根据所述梯度向量更新所述预测网络。
在一个实施例中,参数更新单元313具体配置为:确定预设学习率与训练梯度之间的乘积;将所述图神经网络中的参数更新为其与所述乘积之差。
综上,采用本说明书实施例披露的图神经网络的训练装置,在训练过程中,避开复杂而耗时的对于激活函数的梯度求解,而使用一种代替该梯度的数值,该数据可以由训练数据直接算得,从而有效减少计算次数和计算耗时,使得图神经网络中的参数呈现可控的线性收敛趋势,快速得到训练好的、嵌入表征性能优异的图神经网络。
图4示出根据另一个实施例的图神经网络的训练装置结构图。如图4所示,所述装置400包括以下单元:
图谱获取单元402,配置为获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签。特征融合单元404,配置为针对所述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵。图嵌入单元406,配置为利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数。图表征单元408,配置为利用所述第一关系网络图对应的第一权重向量,对所述多个嵌入向量进行加权处理,得到所述第一关系网络图对应的图表征向量。预测单元410,配置为基于多个关系网络图对应的多个图表征向量,确定多个预测结果。矩阵确定单元412,配置为针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过所述激活函数处理后得到的矩阵进行乘积处理,得到所述多个关系网络图对应的多个乘积矩阵。梯度确定单元414,配置为基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度。参数更新单元416,配置为基于所述训练梯度,更新所述图神经网络中的参数。
在一个实施例中,所述多个业务对象涉及以下中的至少一种:用户、商品、商户、事件、化学元素。
在一个实施例中,所述第一权重向量中各个向量元素的值相等,或者,所述第一权重向量中包含单个非零元素。
在一个实施例中,特征融合单元404具体配置为:针对各个对象节点,对其节点特征与其邻居节点的节点特征进行平均处理,得到所述融合特征。
在一个实施例中,特征融合单元404具体配置为:获取所述多个对象节点对应的节点特征矩阵、度矩阵和邻接矩阵;基于所述度矩阵、邻接矩阵和节点特征矩阵进行相乘处理,得到所述融合特征矩阵。
在一个实施例中,所述训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定;其中梯度确定单元414具体配置为:针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;基于所述平均矩阵以及求和矩阵进行相乘处理,得到所述训练梯度。
在一个实施例中,预测单元410具体配置为:将所述多个图表征向量分别输入预测网络,得到所述多个预测结果;参数更新单元416还配置为:基于所述多个预测结果和业务标签,更新所述预测网络中的参数。
在一个具体的实施例中,所述训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定,所述预测网络利用参数向量对输入的图表征向量进行线性变换处理;其中梯度确定单元414具体配置为:针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;基于所述平均矩阵、求和矩阵以及参数向量进行相乘处理,得到所述训练梯度。
在一个更具的实施例中,参数更新单元416进一步配置为:基于所述多个融合特征矩阵、多个权重向量、多个预测结果与多个业务标签之间的多个差值,以及所述图神经网络的当前参数,确定所述参数向量对应的梯度向量,并根据所述梯度向量更新所述预测网络。
综上,采用本说明书实施例披露的图神经网络的训练装置,在训练过程中,避开复杂而耗时的对于激活函数的梯度求解,而使用一种代替该梯度的数值,该数据可以由训练数据直接算得,从而有效减少计算次数和计算耗时,使得图神经网络中的参数呈现可控的线性收敛趋势,快速得到训练好的、嵌入表征性能优异的图神经网络。
如上,根据又一方面的实施例,还提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行结合图1或图2所描述的方法。
根据又一方面的实施例,还提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现结合图1或图2所描述的方法。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本说明书披露的多个实施例所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本说明书披露的多个实施例的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本说明书披露的多个实施例的具体实施方式而已,并不用于限定本说明书披露的多个实施例的保护范围,凡在本说明书披露的多个实施例的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本说明书披露的多个实施例的保护范围之内。

Claims (20)

1.一种图神经网络的训练方法,包括:
获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签;
针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;
利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;
基于所述多个嵌入向量,确定多个预测结果;
确定所述融合特征矩阵经过所述激活函数处理前后的乘积矩阵;
基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;
基于所述训练梯度,更新所述图神经网络中的参数。
2.根据权利要求1所述的方法,其中,所述多个业务对象涉及以下中的至少一种:用户、商品、商户、事件。
3.根据权利要求1所述的方法,其中,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,包括:
针对所述各个对象节点,对其节点特征与其邻居节点的节点特征进行平均处理,得到所述融合特征。
4.根据权利要求1所述的方法,其中,针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,包括:
获取所述多个对象节点对应的节点特征矩阵、度矩阵和邻接矩阵;
基于所述度矩阵、邻接矩阵和节点特征矩阵进行相乘处理,得到所述融合特征矩阵。
5.根据权利要求1所述的方法,其中,所述训练梯度所对应的损失函数基于所述多个预测结果与多个业务标签之间的差值向量而设定;
其中,基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度,包括:
基于所述乘积矩阵、融合特征矩阵以及差值向量进行相乘处理,得到所述训练梯度。
6.根据权利要求1所述的方法,其中,基于所述多个嵌入向量,确定所述多个对象节点对应的多个预测结果,包括:
将所述多个嵌入向量分别输入预测网络,得到所述多个预测结果;
在得到所述多个预测结果之后,所述方法还包括:
基于所述多个预测结果和业务标签,更新所述预测网络中的参数。
7.根据权利要求6所述的方法,其中,所述训练梯度所对应的损失函数基于所述多个预测结果与多个业务标签之间的差值向量而设定,所述预测网络利用参数向量对输入的嵌入向量进行线性变换处理;
其中,基于所述乘积矩阵、多个预测结果和多个业务标签,确定所述图神经网络中参数的训练梯度,包括:
基于所述乘积矩阵、融合特征矩阵、差值向量和所述参数向量进行相乘处理,得到所述训练梯度。
8.根据权利要求7所述的方法,其中,基于所述多个预测结果和业务标签,更新所述预测网络中的参数,包括:
基于所述融合特征矩阵、所述差值向量,以及所述图神经网络的当前参数,确定所述参数向量对应的梯度向量,并根据所述梯度向量更新所述预测网络。
9.根据权利要求1所述的方法,其中,基于所述训练梯度,更新所述图神经网络中的参数,包括:
确定预设学习率与训练梯度之间的乘积;
将所述图神经网络中的参数更新为其与所述乘积之差。
10.一种图神经网络的训练方法,包括:
获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签;
针对所述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;
利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;
利用所述第一关系网络图对应的第一权重向量,对所述多个嵌入向量进行加权处理,得到所述第一关系网络图对应的图表征向量;
基于多个关系网络图对应的多个图表征向量,确定多个预测结果;
针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过所述激活函数处理后得到的矩阵进行乘积处理,得到所述多个关系网络图对应的多个乘积矩阵;
基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;
基于所述训练梯度,更新所述图神经网络中的参数。
11.根据权利要求10所述的方法,其中,所述多个业务对象涉及以下中的至少一种:用户、商品、商户、事件、化学元素。
12.根据权利要求10所述的方法,其中,所述第一权重向量中各个向量元素的值相等,或者,所述第一权重向量中包含单个非零元素。
13.根据权利要求10所述的方法,其中,所述训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定;
其中,基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度,包括:
针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;
对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;
基于所述平均矩阵以及求和矩阵进行相乘处理,得到所述训练梯度。
14.根据权利要求10所述的方法,其中,基于多个关系网络图对应的多个图表征向量,确定多个预测结果,包括:
将所述多个图表征向量分别输入预测网络,得到所述多个预测结果;
在得到所述多个预测结果之后,所述方法还包括:
基于所述多个预测结果和业务标签,更新所述预测网络中的参数。
15.根据权利要求14所述的方法,其中,所述训练梯度所对应的损失函数基于预测结果与业务标签之间的差值而设定,所述预测网络利用参数向量对输入的图表征向量进行线性变换处理;
其中,基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度,包括:
针对各个关系网络图,对其所对应的融合特征矩阵、权重向量,以及预测结果与业务标签之间的差值进行相乘处理,得到相乘矩阵;
对所述多个关系网络图对应的多个相乘矩阵进行求和,得到求和矩阵;
基于所述平均矩阵、求和矩阵以及参数向量进行相乘处理,得到所述训练梯度。
16.根据权利要求15所述的方法,其中,基于所述多个预测结果和业务标签,更新所述预测网络中的参数,包括:
基于所述多个融合特征矩阵、多个权重向量、多个预测结果与多个业务标签之间的多个差值,以及所述图神经网络的当前参数,确定所述参数向量对应的梯度向量,并根据所述梯度向量更新所述预测网络。
17.一种图神经网络的训练装置,包括:
图谱获取单元,配置为获取关系网络图,其中包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中对象节点携带业务标签;
特征融合单元,配置为针对各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;
图嵌入单元,配置为利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;
预测单元,配置为基于所述多个嵌入向量,确定多个预测结果;
矩阵确定单元,配置为确定所述融合特征矩阵经过所述激活函数处理前后的乘积矩阵;
梯度确定单元,配置为基于所述乘积矩阵、多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;
参数更新单元,配置为基于所述训练梯度,更新所述图神经网络中的参数。
18.一种图神经网络的训练装置,包括:
图谱获取单元,配置为获取多个关系网络图,其中任一的第一关系网络图包括对应多个业务对象的多个对象节点,以及对象节点之间存在关联关系而形成的连接边;其中关系网络图携带业务标签;
特征融合单元,配置为针对所述第一关系网络图中的各个对象节点,将其节点特征与其邻居节点的节点特征进行融合,得到该对象节点的融合特征,所述多个对象节点对应的多个融合特征形成融合特征矩阵;
图嵌入单元,配置为利用所述图神经网络对所述关系网络图进行图嵌入处理,得到所述多个对象节点对应的多个嵌入向量;所述图神经网络中包括激活函数;
图表征单元,配置为利用所述第一关系网络图对应的第一权重向量,对所述多个嵌入向量进行加权处理,得到所述第一关系网络图对应的图表征向量;
预测单元,配置为基于多个关系网络图对应的多个图表征向量,确定多个预测结果;
矩阵确定单元,配置为针对各个关系网络图,基于其对应的融合矩阵和权重向量,及其融合矩阵经过所述激活函数处理后得到的矩阵进行乘积处理,得到所述多个关系网络图对应的多个乘积矩阵;
梯度确定单元,配置为基于所述多个乘积矩阵的平均矩阵,所述多个预测结果和业务标签,确定所述图神经网络中参数的训练梯度;
参数更新单元,配置为基于所述训练梯度,更新所述图神经网络中的参数。
19.一种计算机可读存储介质,其上存储有计算机程序,其中,当所述计算机程序在计算机中执行时,令计算机执行权利要求1-16中任一项的所述的装置。
20.一种计算设备,包括存储器和处理器,其中,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-16中任一项所述的装置。
CN202110177564.2A 2021-02-07 2021-02-07 图神经网络的训练方法及装置 Active CN112766500B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110177564.2A CN112766500B (zh) 2021-02-07 2021-02-07 图神经网络的训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110177564.2A CN112766500B (zh) 2021-02-07 2021-02-07 图神经网络的训练方法及装置

Publications (2)

Publication Number Publication Date
CN112766500A true CN112766500A (zh) 2021-05-07
CN112766500B CN112766500B (zh) 2022-05-17

Family

ID=75705400

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110177564.2A Active CN112766500B (zh) 2021-02-07 2021-02-07 图神经网络的训练方法及装置

Country Status (1)

Country Link
CN (1) CN112766500B (zh)

Cited By (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113221153A (zh) * 2021-05-31 2021-08-06 平安科技(深圳)有限公司 图神经网络训练方法、装置、计算设备及存储介质
CN113344177A (zh) * 2021-05-10 2021-09-03 电子科技大学 基于图注意力的深度推荐方法
CN113408706A (zh) * 2021-07-01 2021-09-17 支付宝(杭州)信息技术有限公司 训练用户兴趣挖掘模型、用户兴趣挖掘的方法和装置
CN113850381A (zh) * 2021-09-15 2021-12-28 支付宝(杭州)信息技术有限公司 一种图神经网络训练方法及装置
CN114580794A (zh) * 2022-05-05 2022-06-03 腾讯科技(深圳)有限公司 数据处理方法、装置、程序产品、计算机设备和介质
CN114819139A (zh) * 2022-03-28 2022-07-29 支付宝(杭州)信息技术有限公司 一种图神经网络的预训练方法及装置
CN115221976A (zh) * 2022-08-18 2022-10-21 抖音视界有限公司 一种基于图神经网络的模型训练方法及装置
CN115359654A (zh) * 2022-08-02 2022-11-18 支付宝(杭州)信息技术有限公司 流量预测系统的更新方法及装置
CN115456109A (zh) * 2022-09-30 2022-12-09 中国电力科学研究院有限公司 电网故障元件辨识方法、系统、计算机设备及存储介质
WO2023011237A1 (zh) * 2021-08-04 2023-02-09 支付宝(杭州)信息技术有限公司 业务处理
CN116192662A (zh) * 2023-05-04 2023-05-30 中国电信股份有限公司四川分公司 基于业务行为预测与确定性网络关联模型及推荐方法
CN116562357A (zh) * 2023-07-10 2023-08-08 深圳须弥云图空间科技有限公司 点击预测模型训练方法及装置
CN117273086A (zh) * 2023-11-17 2023-12-22 支付宝(杭州)信息技术有限公司 多方联合训练图神经网络的方法及装置
WO2024021738A1 (zh) * 2022-07-29 2024-02-01 腾讯科技(深圳)有限公司 数据网络图的嵌入方法、装置、计算机设备和存储介质

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20100183225A1 (en) * 2009-01-09 2010-07-22 Rochester Institute Of Technology Methods for adaptive and progressive gradient-based multi-resolution color image segmentation and systems thereof
CN103607768A (zh) * 2013-03-04 2014-02-26 华为技术有限公司 一种非集中式场景下的目标设备定位方法及相关设备
CN110009093A (zh) * 2018-12-07 2019-07-12 阿里巴巴集团控股有限公司 用于分析关系网络图的神经网络系统和方法
CN110929870A (zh) * 2020-02-17 2020-03-27 支付宝(杭州)信息技术有限公司 图神经网络模型训练方法、装置及系统
CN110990871A (zh) * 2019-11-29 2020-04-10 腾讯云计算(北京)有限责任公司 基于人工智能的机器学习模型训练方法、预测方法及装置
CN111309983A (zh) * 2020-03-10 2020-06-19 支付宝(杭州)信息技术有限公司 基于异构图进行业务处理的方法及装置
CN112085615A (zh) * 2020-09-23 2020-12-15 支付宝(杭州)信息技术有限公司 图神经网络的训练方法及装置
CN112085172A (zh) * 2020-09-16 2020-12-15 支付宝(杭州)信息技术有限公司 图神经网络的训练方法及装置
CN112100387A (zh) * 2020-11-13 2020-12-18 支付宝(杭州)信息技术有限公司 用于文本分类的神经网络系统的训练方法及装置

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20100183225A1 (en) * 2009-01-09 2010-07-22 Rochester Institute Of Technology Methods for adaptive and progressive gradient-based multi-resolution color image segmentation and systems thereof
CN103607768A (zh) * 2013-03-04 2014-02-26 华为技术有限公司 一种非集中式场景下的目标设备定位方法及相关设备
CN110009093A (zh) * 2018-12-07 2019-07-12 阿里巴巴集团控股有限公司 用于分析关系网络图的神经网络系统和方法
CN110990871A (zh) * 2019-11-29 2020-04-10 腾讯云计算(北京)有限责任公司 基于人工智能的机器学习模型训练方法、预测方法及装置
CN110929870A (zh) * 2020-02-17 2020-03-27 支付宝(杭州)信息技术有限公司 图神经网络模型训练方法、装置及系统
CN111309983A (zh) * 2020-03-10 2020-06-19 支付宝(杭州)信息技术有限公司 基于异构图进行业务处理的方法及装置
CN112085172A (zh) * 2020-09-16 2020-12-15 支付宝(杭州)信息技术有限公司 图神经网络的训练方法及装置
CN112085615A (zh) * 2020-09-23 2020-12-15 支付宝(杭州)信息技术有限公司 图神经网络的训练方法及装置
CN112100387A (zh) * 2020-11-13 2020-12-18 支付宝(杭州)信息技术有限公司 用于文本分类的神经网络系统的训练方法及装置

Cited By (23)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113344177A (zh) * 2021-05-10 2021-09-03 电子科技大学 基于图注意力的深度推荐方法
CN113221153A (zh) * 2021-05-31 2021-08-06 平安科技(深圳)有限公司 图神经网络训练方法、装置、计算设备及存储介质
CN113408706A (zh) * 2021-07-01 2021-09-17 支付宝(杭州)信息技术有限公司 训练用户兴趣挖掘模型、用户兴趣挖掘的方法和装置
CN113408706B (zh) * 2021-07-01 2022-04-12 支付宝(杭州)信息技术有限公司 训练用户兴趣挖掘模型、用户兴趣挖掘的方法和装置
WO2023011237A1 (zh) * 2021-08-04 2023-02-09 支付宝(杭州)信息技术有限公司 业务处理
CN113850381A (zh) * 2021-09-15 2021-12-28 支付宝(杭州)信息技术有限公司 一种图神经网络训练方法及装置
CN114819139A (zh) * 2022-03-28 2022-07-29 支付宝(杭州)信息技术有限公司 一种图神经网络的预训练方法及装置
CN114580794A (zh) * 2022-05-05 2022-06-03 腾讯科技(深圳)有限公司 数据处理方法、装置、程序产品、计算机设备和介质
CN114580794B (zh) * 2022-05-05 2022-07-22 腾讯科技(深圳)有限公司 数据处理方法、装置、程序产品、计算机设备和介质
WO2023213157A1 (zh) * 2022-05-05 2023-11-09 腾讯科技(深圳)有限公司 数据处理方法、装置、程序产品、计算机设备和介质
WO2024021738A1 (zh) * 2022-07-29 2024-02-01 腾讯科技(深圳)有限公司 数据网络图的嵌入方法、装置、计算机设备和存储介质
CN115359654B (zh) * 2022-08-02 2023-09-08 支付宝(杭州)信息技术有限公司 流量预测系统的更新方法及装置
CN115359654A (zh) * 2022-08-02 2022-11-18 支付宝(杭州)信息技术有限公司 流量预测系统的更新方法及装置
CN115221976A (zh) * 2022-08-18 2022-10-21 抖音视界有限公司 一种基于图神经网络的模型训练方法及装置
CN115221976B (zh) * 2022-08-18 2024-05-24 抖音视界有限公司 一种基于图神经网络的模型训练方法及装置
CN115456109A (zh) * 2022-09-30 2022-12-09 中国电力科学研究院有限公司 电网故障元件辨识方法、系统、计算机设备及存储介质
CN115456109B (zh) * 2022-09-30 2023-11-24 中国电力科学研究院有限公司 电网故障元件辨识方法、系统、计算机设备及存储介质
CN116192662A (zh) * 2023-05-04 2023-05-30 中国电信股份有限公司四川分公司 基于业务行为预测与确定性网络关联模型及推荐方法
CN116192662B (zh) * 2023-05-04 2023-06-23 中国电信股份有限公司四川分公司 基于业务行为预测与确定性网络关联模型及推荐方法
CN116562357A (zh) * 2023-07-10 2023-08-08 深圳须弥云图空间科技有限公司 点击预测模型训练方法及装置
CN116562357B (zh) * 2023-07-10 2023-11-10 深圳须弥云图空间科技有限公司 点击预测模型训练方法及装置
CN117273086A (zh) * 2023-11-17 2023-12-22 支付宝(杭州)信息技术有限公司 多方联合训练图神经网络的方法及装置
CN117273086B (zh) * 2023-11-17 2024-03-08 支付宝(杭州)信息技术有限公司 多方联合训练图神经网络的方法及装置

Also Published As

Publication number Publication date
CN112766500B (zh) 2022-05-17

Similar Documents

Publication Publication Date Title
CN112766500B (zh) 图神经网络的训练方法及装置
Gao et al. Projection-based link prediction in a bipartite network
CN111881350B (zh) 一种基于混合图结构化建模的推荐方法与系统
CN112541575B (zh) 图神经网络的训练方法及装置
CN112085615A (zh) 图神经网络的训练方法及装置
US20230049817A1 (en) Performance-adaptive sampling strategy towards fast and accurate graph neural networks
CN112528110A (zh) 确定实体业务属性的方法及装置
Han et al. On weighted support vector regression
Hu et al. HeteroSales: Utilizing heterogeneous social networks to identify the next enterprise customer
Lyu et al. Memorize, factorize, or be naive: Learning optimal feature interaction methods for CTR prediction
CN113610610B (zh) 基于图神经网络和评论相似度的会话推荐方法和系统
Weng et al. GAIN: Graph attention & interaction network for inductive semi-supervised learning over large-scale graphs
CN110717116B (zh) 关系网络的链接预测方法及系统、设备、存储介质
Öztemiz et al. KO: Modularity optimization in community detection
Sahu et al. Matrix factorization in cross-domain recommendations framework by shared users latent factors
CN114861072B (zh) 一种基于层间组合机制的图卷积网络推荐方法及装置
Ranjith et al. A multi objective teacher-learning-artificial bee colony (MOTLABC) optimization for software requirements selection
JP2023543128A (ja) 動的アテンショングラフネットワークに基づくマーケティング裁定取引ネット暗黒産業の識別方法
Zhou et al. Forecasting credit default risk with graph attention networks
CN112559640A (zh) 图谱表征系统的训练方法及装置
Mohammed et al. Location-aware deep learning-based framework for optimizing cloud consumer quality of service-based service composition
Xie et al. A reinforcement learning approach to optimize discount and reputation tradeoffs in e-commerce systems
Jo et al. AutoGAN-DSP: Stabilizing GAN architecture search with deterministic score predictors
Haotian et al. RECAL: Sample-Relation Guided Confidence Calibration over Tabular Data
Chen et al. Semi-supervised heterogeneous graph learning with multi-level data augmentation

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant