CN111445020B - 一种基于图的卷积网络训练方法、装置及系统 - Google Patents
一种基于图的卷积网络训练方法、装置及系统 Download PDFInfo
- Publication number
- CN111445020B CN111445020B CN201910041725.8A CN201910041725A CN111445020B CN 111445020 B CN111445020 B CN 111445020B CN 201910041725 A CN201910041725 A CN 201910041725A CN 111445020 B CN111445020 B CN 111445020B
- Authority
- CN
- China
- Prior art keywords
- node
- layer
- storage space
- graph
- central node
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于图的卷积网络训练方法、装置及系统。所述方法包括:为卷积模型除了最低层和最高层之外的每层建立第一存储空间;基于每个批次的训练数据和图,确定训练数据的各中心节点及中心节点的各邻居节点;针对每个中心节点,从前一层中心节点的各邻居节点标识对应的第一存储空间中获取各邻居节点的表征向量;根据前一层传递来的中心节点的表征向量及获取到的各邻居节点的表征向量,确定中心节点在本层中的表征向量,当当前层并非最低层或者最高层时,将确定出的表征向量传递至本层相邻的下一层并更新本层第一存储空间中中心节点标识对应的表征向量;直至得到中心节点在最高层的表征向量。本发明可有效降低训练的计算量和训练时间。
Description
技术领域
本发明涉及机器学习技术领域,特别涉及一种基于图的卷积网络训练方法、装置及系统。
背景技术
随着移动终端及应用软件的普及,在社交、电商、物流、出行、外卖、营销等领域的服务提供商沉淀了海量业务数据,基于海量业务数据,挖掘不同业务实体(实体)之间的关系成为数据挖掘领域一个重要的技术研究方向。而随着机器处理能力的提升,越来越多技术人员开始研究如何通过机器学习技术进行挖掘。
本发明的发明人发现:目前,通过机器学习技术,对海量业务数据进行学习,得到用于表达实体及实体之间关系的图(Graph),即,对海量业务数据进行图学习,成为一个优选的技术方向。简单理解,图由节点和边构成,图中的每个序号代表一个节点,一个节点用于表示一个实体,节点与节点之间的边用于表示节点之间的关系。一张图一般会包括两个以上的节点和一条以上的边,因此,图也可以理解为由节点的集合和边的集合组成,通常表示为:G(V,E),其中,G表示图,V表示图G中节点的集合,E是图G中边的集合。图可以分为同构图和异构图,其中,异构图指的是一张图中的节点的类型不同(边的类型可以相同或者不同),或者一张图中边的类型不同(节点的类型可以相同或者不同)。图1所示则为一张异构图,同样类型的边用同样的线形表示,同样类型的点用同样的几何图形表示。
图卷积网络(GCN,Graph Convolution Network)是一种有效的图表征的学习方法,在很多关键任务上取得了超过以往的方法的效果。
为了获取图中每个节点的表征向量,图卷积网络通常的做法是对于每个节点聚合其邻居节点的表征向量(表征向量通常是需要训练的GCN各节点的属性特征值),进而得到更高一层节点的结果,依次进行层迭代计算。
容易看到此类方法的计算量随着层数的增多而指数级递增,随着应用的深入,对于节点数量庞大的图来说,可导致模型的整体训练时间不可接受。
高昂的计算力的需求增加了GCN模型训练的时间,制约了GCN模型在各种实际应用例如搜索、广告投放、推荐、社交网络挖掘等方面中的应用。
发明内容
鉴于上述问题,提出了本发明以便提供一种克服上述问题或者至少部分地解决上述问题的一种基于图的卷积网络训练方法、装置及系统。
第一方面,本发明实施例提供一种基于图的卷积网络训练方法,包括:
根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立第一存储空间,每个第一存储空间与一个节点标识对应且用于存储所述节点的特征向量;
基于每个批次的训练数据和所述图,确定所述训练数据的各中心节点的标识及所述中心节点的各邻居节点标识;
针对每个中心节点,从前一层所述中心节点的各邻居节点标识对应的第一存储空间中获取所述各邻居节点的表征向量;
根据前一层传递来的所述中心节点的表征向量及获取到的所述各邻居节点的表征向量,确定所述中心节点在本层中的表征向量,当当前层并非最低层或者最高层时,将确定出的表征向量传递至本层相邻的下一层并更新本层第一存储空间中所述中心节点标识对应的表征向量;
直至得到所述中心节点在最高层的表征向量。
在一个实施例中,邻居节点标识对应的第一存储空间中邻居节点的表征向量为:预先写入所述存储空间中各邻居节点表征向量的初始值;或者为所述邻居节点在其他批次的训练数据中作为中心节点时更新的表征向量。
在一个实施例中,根据前一层传递来的所述中心节点的表征向量及获取到的所述各邻居节点的表征向量,确定所述中心节点在本层中的表征向量,包括:
将所述前一层传递来的所述中心节点的表征向量与获取到的所述各邻居节点的表征向量中各邻居节点的特征向量值进行聚合,得到所述中心节点在本层中的表征向量。
在一个实施例中,上述方法还包括:
根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立第二存储空间,每个第二存储空间与一个节点标识对应且用于存储所述节点的梯度值;
在得到所述中心节点在最高层的表征向量之后,所述方法还包括:
将中心节点在最高层的表征向量和所述中心节点在所述图中的标签label输入预设的损失函数,输出所述中心节点在最高层的梯度值;
通过所述中心节点在最高层的梯度值,更新所述最高层相邻的前一层的各邻居节点的梯度值和所述中心节点的梯度值;当所述前一层并非最低层时,将更新后的所述前一层的各邻居节点的梯度值写入前一层所述邻居节点标识对应的第二存储空间中;
从前一层的中心节点标识对应的第二存储空间中获取所述中心节点的梯度值;
根据从第二存储空间中获取到的中心节点的梯度值及更新后的中心节点的梯度值,得到所述中心节点在所述前一层的梯度值;
循环执行上述流程,直至得到最低层各邻居节点和中心节点的梯度值。
在一个实施例中,上述第二存储空间中所述中心节点的梯度值为:预设的初始梯度值;或者为所述中心节点在其他批次的训练数据中作为邻居节点时更新的梯度值。
在一个实施例中,所述从前一层的中心节点标识对应的第二存储空间中获取所述中心节点的梯度值之后,还包括:
将所述前一层的中心节点标识对应的第二存储空间中的梯度值还原为所述初始梯度值。
第二方面,本发明实施例提供一种基于图的卷积网络训练装置,包括:
存储模块,用于提供卷积模型除了最低层和最高层之外的每一层中节点标识对应的第一存储空间,每个第一存储空间用于存储所述节点的特征向量;所述第一存储空间是根据预定义的图中的节点个数和预设的卷积模型层数预先建立的;
节点确定模块,用于基于每个批次的训练数据和所述图,确定所述训练数据的各中心节点的标识及所述中心节点的各邻居节点标识;
获取模块,用于针对每个中心节点,从前一层所述中心节点的各邻居节点标识对应的第一存储空间中获取所述各邻居节点的表征向量;
表征向量确定模块,用于根据所述中心节点在前一层中的表征向量及获取到的所述各邻居节点的表征向量,确定所述中心节点在本层中的表征向量;直至得到所述中心节点在最高层的表征向量;
向量传递模块,用于当当前层并非最低层或者最高层时,将确定出的表征向量传递至本层相邻的下一层;
更新模块,用于当当前层并非首层或者最后一层时,更新本层第一存储空间中所述中心节点标识对应的表征向量。
第三方面,本发明实施例提供一种计算节点装置,包括:处理器、用于存储处理器可执行命令的存储器;其中,处理器被配置为可执行如权利要求1-6任一项所述的基于图的卷积网络训练方法。
第四方面,本发明实施例提供一种基于图的卷积网络训练系统,包括:至少一个前述计算节点装置和至少两个存储节点装置;
所述计算节点装置,用于从所述至少一个存储节点装置中获取所述图的数据和基于所述图的各批次的训练数据;
所述至少两个存储节点装置,用于分布式地存储所述图,以及提供所述第一存储空间和/或第二存储空间。
第五方面,本发明实施例提供一种非临时性计算机可读存储介质,当所述存储介质中的指令由处理器执行时,能够实现前述基于图的卷积网络训练方法。
本发明实施例提供的上述技术方案的有益效果至少包括:
本发明实施例提供的上述基于图的卷积网络训练方法、装置及系统中,根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了首层和最后一层之外的每一层建立第一存储空间,该第一存储空间用于存储所述节点的特征向量,这样,针对每批次的训练数据中的每个中心节点,在计算其下一层的表征向量时,对于其前一层的邻居节点的表征向量,可直接从对应层的第一存储空间中获取,进而确定出中心节点在本层的表征向量,并更新至本层中心节点标识对应的第一存储空间中,以便在后续其他批次的训练数据中该节点作为其他中心节点的邻居节点时直接进行复用,这样,针对数据量庞大的训练数据而言,每批次训练数据,只需要计算每批次训练数据中中心节点的卷积结果,而无需对扩散后的所有邻居节点再进行卷积计算,可实现计算量与层数呈线性关系,降低了训练的计算量和训练时间,避免了现有训练过程中计算量与层数呈指数关系所带来的计算量庞大、整体训练时间不可接受的问题。
另外,由于在卷积网络训练过程的前向计算中采用了对于其前一层的邻居节点的表征向量,直接从对应层的第一存储空间中获取的方式,而未实际对每层的邻居节点向下进行扩展,因此,在对应的反向计算过程中,为了解决邻居节点的梯度反向传递问题,本发明还根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立了第二存储空间,该第二存储空间用于存储节点的梯度值;这样,对最高层和非最低层而言,可根据本层中心节点的梯度,更新前一层中心节点和各邻居节点的梯度值,并将前一层邻居节点的梯度保存于前一层预设的第二存储空间中。再根据更新后的中心节点的梯度值与从该前一层该中心节点标识对应的第二存储空间获取该中心节点的梯度值,最终得到用于向下传递的该中心节点的梯度值。如此循环,使用前一层中心节点的梯度值继续向前一层相邻的再前一层传递,直至最低层。这样,在该邻居节点在其他批次训练数据中再次作为中心节点时,可以顺利地将梯度传递至前一层的中心节点和各邻居节点,以实现整个卷积网络训练中梯度值的反向传递。
本发明的其它特征和优点将在随后的说明书中阐述,并且,部分地从说明书中变得显而易见,或者通过实施本发明而了解。本发明的目的和其他优点可通过在所写的说明书、权利要求书、以及附图中所特别指出的结构来实现和获得。
下面通过附图和实施例,对本发明的技术方案做进一步的详细描述。
附图说明
附图用来提供对本发明的进一步理解,并且构成说明书的一部分,与本发明的实施例一起用于解释本发明,并不构成对本发明的限制。在附图中:
图1为异构图的一个例子的示意图;
图2为现有技术中图卷积网络模型的训练方式的示意图;
图3为本发明实施例提供的基于图的卷积网络训练方法的流程图之一;
图4为本发明实施例提供的实例中图卷积网络的模型的前向计算的过程示意图;
图5为本发明实施例提供的基于图的卷积网络训练方法的流程图之二;
图6为本发明实施例提供的基于图的卷积网络训练方法的流程图之三;
图7为本发明实施例提供的实例中图卷积网络的模型的后向计算的过程示意图;
图8为本发明实施例提供的基于图的卷积网络训练装置的结构示意图。
具体实施方式
下面将参照附图更详细地描述本公开的示例性实施例。虽然附图中显示了本公开的示例性实施例,然而应当理解,可以以各种形式实现本公开而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更透彻地理解本公开,并且能够将本公开的范围完整的传达给本领域的技术人员。
图学习面向业务场景才是有意义的,所以,当基于业务场景确定了图中的节点对应的实体和边对应的实体间关系后,图则被赋予了业务含义和技术含义,按照该业务场景要解决的技术问题和业务问题执行相应的图学习任务,则可得到解决相应问题的结果。比如,图表示学习可将复杂的图表示成低维、实值、稠密的向量形式,使其具有表示及推理能力,可以方便执行其他机器学习任务。
现有技术中,为了获取图中每个节点的表征向量,可能采用多种学习方法,图卷积网络(GCN,Graph Convolution Network)就是其中之一,图卷积网络通常的做法是对于每个节点聚合其邻居节点的表征向量(表征向量通常是需要训练的GCN各节点的属性特征值),进而得到更高一层节点的结果,然后依次进行层迭代计算。
此类方法的计算量可随着层数的增多而指数级递增,参照图2所示的图卷积网络的模型,假设使用每层度为2的邻居节点采样来实现h(0)~h(3)层的卷积网络,则对于每个节点来说,其每层的表征向量都依赖于前一层的包括自身和其2个邻居节点在内的3个节点的表征向量;比如,每个h(1)的表征向量,都依赖于前一层即3个h(0)的表征向量来计算,以此类推,为了获取最后一层h(3)节点的表征向量,从计算量上来说,则需要计算15个h(0)、7个h(1)、3个h(2)和1个h(3)的节点的相关表征向量。而为了在规模图上降低采样带来的精度损失,需要更宽的邻居节点采样,也就是每个节点其前一层需要参与计算的邻居节点的数量更多,例如为5个或10个,这样话,各节点的表征向量的计算量膨胀会更为显著,导致模型的整体训练时间不可接受。
为了解决上述基于图的卷积网络的训练问题中存在的计算量庞大所带来的训练时间不可接受的问题,本发明实施例提供了一种基于图的卷积网络训练的方法,该方法可以适用于任何种类的图,包括各种不同类型的同构图和异构图。在这些图中,包含若干彼此相连的节点,每个节点存储有自身与其他节点的连接关系、自身属性和与其他节点相连的边的属性等信息,这些节点可以根据图的类型和用途的不同,而可以呈现为各种实体对象的表现形式,例如为商品购买网络中的商品、查询关键词、用户、搜索广告等;或者为社交网络中的用户、用户偏好等。本发明实施例对图和节点的所表征的是何种对象并不限定。
下面结合附图,对本发明实施例提供一种基于图的卷积网络训练方法的具体实施方式进行详细说明。
参照图3所示,本发明实施例提供的基于图的卷积网络训练方法,包括下述步骤:
S31、根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立第一存储空间,每个第一存储空间与一个节点标识对应且用于存储节点的特征向量。
本发明的发明人发现,由于考虑到不同批次中心节点在不同批次的训练数据被大量复用的情形,本发明实施例对GCN的训练过程中为图卷积网络模型引入了第一存储空间,在缓存中直接查询前一层邻居节点的表征向量,参与本层中心节点的表征向量的运算,而不用对前一层邻居节点进行实际的扩散得到其表征向量,因此实现计算量与层数呈线性关系,大大降低GCN的训练时间。
本步骤S31可以在下述步骤S32~S36之前预先完成。
在本发明实施例中,步骤S31~S36对图卷积网络学习的前向运算进行了改进。前向运算是指从GCN的最低层,逐步学习到最高层的中心节点的表征向量的过程。
在本发明实施例中,仅为了区分存储空间所存储的内容的不同,将与卷积模型中除了最低层和最高层之外的每一层所对应的存储空间分为两类,分别称呼为第一存储空间和第二存储空间,第一存储空间用于存储在图卷积网络训练过程前向运算过程中需要使用的节点的特征向量;而第二存储空间用于存储后向运算过程中需要使用的节点的梯度值,随后会在GCN后向运算过程中说明。
在具体实施时,上述第一存储空间和第二存储空间在物理上可以是一体的,也就是可以位于同一个物理存储区域上,也可以是在物理上分隔的不同区域,即分属于不同的物理存储区域。
具体实施时,第一存储空间和第二存储空间可以采用各类缓存等,方便根据节点的数量和图卷积网络模型的层数灵活设置空间大小,并在图卷积网络模型学习结束后及时释放存储空间,具有较强的灵活性,本发明实施例对采用何种缓存并不做限定。
图卷积网络模型的层数,决定了需要几层的第一存储空间。举例来说,假设图卷积网络模型的层数为n,则除了最低层和最高层之外,一共有n-2层需要设置对应的第一存储空间。
对于每一层来说,预定义的图中有多少节点,不论是同构节点还是异构的节点,就需要有多少个第一存储空间,每个第一存储空间都分别与一个节点标识对应。
假设图中一共有10万个节点,需要设置第一存储空间的层数是2层,那么,则一共需要20万个第一存储空间。每个层的第一存储空间都是独立的。
针对一个预定义的图来说,每批次的训练数据中包含了该批次中各中心节点的标识,各个批次的中心节点则涵盖了整个图中的所有节点,举个简单例子来说,假设图中有100个节点,分为10个批次的训练数据,则每个批次的训练数据中就包含了10个中心节点。
S32、基于每个批次的训练数据和图,确定训练数据的各中心节点的标识及该中心节点的各邻居节点标识;
对于包含数量庞大的节点的图来说,由于计算能力有限,通常可以采用分批次的训练数据进行训练的方式。例如划分为若干个min-batch,每个min-batch作为一个批次的训练数据。每个批次的训练数据中,都包含了本批次的各中心节点的标识。
根据每批训练数据中中心节点的标识以及图,就可以确定每批次训练数据中每个中心节点的各邻居节点的标识。
S33、针对每个中心节点,从前一层中心节点的各邻居节点标识对应的第一存储空间中获取各邻居节点的表征向量;
在本步骤S33中,每个中心节点的计算方式与现有技术不同,是直接从前一层的第一存储空间中,找到该中心节点的各邻居节点的标识所对应的那些第一存储空间,从其中读取这些邻居节点在前一层的表征向量。
在本发明实施例中,在GCN开始训练之前,每个第一存储空间中都存有该节点在该层的表征向量的一个初始值,这个预设的初始值可以是个固定的数值,也可以是个随机生成的随机数。随着训练的开始,不同批次的中心节点在训练时,作为中心节点的节点对应的第一存储空间的数据逐渐被更新。
因此,从第一存储空间读出表征向量,可以是下面两种情形,一种是,这个第一存储空间的数据从未被更新过,还是预设的初始值,换言之,这个邻居节点在本批次数据之前未作为中心节点更新过其自身的表征向量。另一种是,这个邻居节点在本批次之前已作为其他批次的中心节点,将初始值更新为计算后的表征向量值。
S34、根据前一层传递来的中心节点的表征向量及获取到的各邻居节点的表征向量,确定中心节点在本层中的表征向量;
在本步骤S34来说,可以采用将前一层传递来的所述中心节点的表征向量与获取到的所述各邻居节点的表征向量中各邻居节点的特征向量值进行聚合的方式,得到中心节点在本层中的表征向量。
用数学的方式例如可以表达为:
要计算中心节点v在l+1层的表征向量,上述公式1是将该中心节点v在第l层的所有邻居节点的表征向量(从对应的第一存储空间中获取)先进行预设的Aggregation算法的聚合运算,然后再通过公式2,再将前面的聚合的结果再与该中心节点在第l层的表征向量再次进行另一类型的combine算法的聚合运算。
当然,还可以一次性将中心节点v在第l层的所有邻居节点的表征向量和中心节点在第l层的表征向量进行一次聚合运算,本发明实施例对采用何种聚合的方式不做限定。
聚合算法可以采用现有任何一种算法,本发明实施例对此不做限定,例如采用mean(平均)、加权平均等等。
S35、判断当前层是否为非最低层或者最高层,当判断为非最低层或者非最高层时,执行下述步骤S36;否则,结束流程;
由于步骤S31~S36是前向计算过程,即最低层逐层向最高层运算的过程,所以在本步骤S35中,当判断当前层为最高层时,其实就是步骤S34中已计算得到了中心节点在最高层中的表征向量,此时前向计算过程就结束了。
S36、将确定出的表征向量传递至本层相邻的下一层并更新本层第一存储空间中中心节点标识对应的表征向量;然后再转向执行步骤S33,循环步骤S33~S36,直到流程结束。
中心节点在非最低层的表征向量,都可以从上一层计算后直接传递过来,而不需要从上一层对应的第一存储空间中读取,在本步骤S26中,将计算出来的中心节点传递至相邻的下一层时,可以同时更新本层第一存储空间中的存储的表征向量。
为了更好地说明上述步骤S31~S36,下面以一个简单的图卷积网络模型的训练过程中的前向计算过程进行说明:
参照图4所示,该图卷积网络的模型共包含h0~h3层,位于每层的中间的黑色节点为中心节点,在每层中,与黑色节点连接的灰色的3个节点为其邻居节点。非最低层和非最高层的h1层和h2层分别预设有对应的缓存空间和/>缓存空间/>和/>中包含若干小的缓存,分别存储每个节点在本层的表征向量。
最低层h0的中心节点和邻居节点的表征向量,是原始值,可以直接从图中得到。
对于h1层的中心节点的表征向量,是通过h0层的3个邻居节点和中心节点的表征向量进行聚合得到,然后将h1层的中心节点的表征向量写入对应的缓存中(见图4中h1层中心节点指向缓存空间/>的箭头),以便在下次该中心节点作为其他批次的训练数据中的邻居节点时使用,同时,得到的h1层的中心节点的表征向量继续向相邻的下一层传递(见图4中h1层中心节点指向缓存空间下一层h2中心节点的箭头)。
h2层的中心节点的表征向量的得到方式与前面方式类似,即:先是从缓存空间中读取中心节点的3个邻居节点在h1层的表征向量,然后将读取到的3个邻居节点在h1层的表征向量与从h1层传递上来的该中心节点在h1层的表征向量进行聚合,得到中心节点在h2层的表征向量,然后再将中心节点在h2层的表征向量写入到缓存空间/>对应的位置。
h3层的中心节点的表征向量的得到方式与前面的方式类似,不同的在于,在得到中心节点在h3层的表征向量后,不需要将其更新到对应的缓存空间中。
本发明实施例提供的上述基于图的卷积网络训练方法中,根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了首层和最后一层之外的每一层建立第一存储空间,该第一存储空间用于存储所述节点的特征向量,这样,针对每批次的训练数据中的每个中心节点,在计算其下一层的表征向量时,对于其前一层的邻居节点的表征向量,可直接从对应层的第一存储空间中获取,进而确定出中心节点在本层的表征向量,并更新至本层中心节点标识对应的第一存储空间中,以便在后续其他批次的训练数据中该节点作为其他中心节点的邻居节点时直接进行复用,这样,针对数据量庞大的训练数据而言,每批次训练数据,只需要计算每批次训练数据中中心节点的卷积结果,而无需对扩散后的所有邻居节点再进行卷积计算,可实现计算量与层数呈线性关系,降低了训练的计算量和训练时间,避免了现有训练过程中计算量与层数呈指数关系所带来的计算量庞大、整体训练时间不可接受的问题。
对于每批次训练数据来说,图卷积网络模型的训练通常会包含前向计算过程,即获得最高层的中心节点的表征向量,后续还需要按照最高层中心节点的表征向量进行反向梯度的传递,也就是后向计算(Back Propagation)的过程。
由于在卷积网络训练过程的前向计算中采用了对于其前一层的邻居节点的表征向量,直接从对应层的第一存储空间中获取的方式,而未实际对每层的邻居节点向下进行扩展,所以,现有GCN后向计算过程无法适用于本发明实施例,为了解决后向计算的过程的上述问题,本发明实施例提供的卷积网络训练方法,在上述S31~S36之后,参照图5所示,还可以执行下述流程:
S51、将中心节点在最高层的表征向量和中心节点在图中的标签label输入预设的损失函数,输出中心节点在最高层的梯度值;
与步骤S31类似,根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立第二存储空间,每个第二存储空间与一个节点标识对应且用于存储节点的梯度值。
S52、从最高层开始,在每层,通过中心节点在本层的梯度值,更新本层相邻的前一层的各邻居节点的梯度值和中心节点的梯度值;
S53、判断前一层是否为最低层,若是,结束流程,若否,继续执行下述步骤S54;
S54、将更新后的前一层的各邻居节点的梯度值写入前一层邻居节点标识对应的第二存储空间中;
S55、从前一层的中心节点标识对应的第二存储空间中获取中心节点的梯度值;
S56、根据从第二存储空间中获取到的中心节点的梯度值及更新后的中心节点的梯度值,最终得到中心节点在最高层的前一层的梯度值,这个梯度值用于继续向再前一层传递;然后跳转至步骤S52,再次循环,直至得到最低层的各邻居节点的梯度值和中心节点的梯度值。
由于后向计算过程中,中心节点的梯度是可以直接向下传递的,而作为邻居节点,因为没有实际扩展,没办法向下传递,所以在上述步骤S51~S56中,将邻居节点在每层(非最高层和非最低层)的梯度进行缓存,以便这些邻居节点后续在其他批次中作为中心节点中可以直接使用这些梯度数据,实现梯度向下传递。
在卷积模型除了最低层和最高层之外的每一层建立第二存储空间中,在训练初始时,其中每个节点标识对应的第二存储空间存储的是各节点梯度的初始值。随着每批次的训练数据的训练过程的进行,在后向计算过程中,作为邻居节点的节点对应的第二存储空间中的初始值逐渐被更新为其他数值,从而可以在其他批次的训练中作为中心节点时,其梯度值可以进行后向传递。
节点梯度的初始值,较佳地,可以设置为0。
参照图6所示,在上述步骤S55中从前一层的中心节点标识对应的第二存储空间中获取中心节点的梯度值之后,上述流程还可以包括下述步骤:
S57、将前一层的中心节点标识对应的第二存储空间中的梯度值还原为初始梯度值。
还是以一个图4所示的GCN模型的例子来说明,参照图7所示,该图卷积网络的模型共包含h0~h3层,黑色节点是中心节点,灰色节点是其邻居节点。非最低层和非最高层的h1层和h2层分别预设有对应的缓存空间和/>缓存空间/>和/>中包含若干小的缓存,分别存储每个节点在本层的梯度值。
5、将从缓存空间中读取的中心节点的梯度值,与步骤2中中心节点更新后的梯度值再次进行运算,得到h2层中心节点最终的梯度值(见图7中h2层中心节点指向BackProp的箭头),也就是用于继续向下传递的/>
上述从h2层中心节点对应的缓存空间中读取其中的梯度值,该梯度值是中心节点的累计未消费梯度,实际上是该中心节点在作为邻居节点时的梯度进行延迟的后向计算,在消费掉该梯度值后(也就是读取了该梯度值后),将其值重置为初始值(例如为0)。
采用本发明实施例提供的上述基于图的卷积网络训练方法,本发明的发明人经过试验发现,相对于现有技术,本发明实施例可大大缩短GCN模型的训练时间。
试验中使用512个min-batch训练20个神经网络模型,并在开源数据集上验证上述方法的执行效率。
采用本发明实施例提供的上述训练方法称为ScalableGCN,现有技术中GCN训练的方式选择了GraphSAGE。
两种方式训练的结果相差很小,因此采用本发明实施例提供的上述方法可以取得多层卷积网络模型的收益,但是在训练时间上,下表1是以Reddi数据集(23万个节点)上每个min-batch两种方式所需训练的时间的对比:
表1 单位秒
GCN模型类型 | ScalableGCN | GraphSAGE |
2layer(2层) | 0.026 | 0.120 |
3layer(3层) | 0.035 | 1.119 |
注意到,采用本发明实施例的上述方法,可以大大每批次压缩训练的时间,对GCN训练的时间相对于卷积模型层数是呈线性的。
基于同一发明构思,本发明实施例还提供了一种基于图的卷积网络训练装置、计算节点装置和基于图的卷积网络训练系统,由于这些装置和系统所解决问题的原理与前述基于图的卷积网络训练方法相似,因此该装置和系统的实施可以参见前述方法的实施,重复之处不再赘述。
本发明实施例提供的基于图的卷积网络训练装置,参照图8所示,包括:
存储模块81,用于提供卷积模型除了最低层和最高层之外的每一层中节点标识对应的第一存储空间,每个第一存储空间用于存储所述节点的特征向量;所述第一存储空间是根据预定义的图中的节点个数和预设的卷积模型层数预先建立的;
节点确定模块82,用于基于每个批次的训练数据和所述图,确定所述训练数据的各中心节点的标识及所述中心节点的各邻居节点标识;
获取模块83,用于针对每个中心节点,从前一层所述中心节点的各邻居节点标识对应的第一存储空间中获取所述各邻居节点的表征向量;
表征向量确定模块84,用于根据所述中心节点在前一层中的表征向量及获取到的所述各邻居节点的表征向量,确定所述中心节点在本层中的表征向量;直至得到所述中心节点在最高层的表征向量;
向量传递模块85,用于当当前层并非最低层或者最高层时,将确定出的表征向量传递至本层相邻的下一层;
更新模块86,用于当当前层并非首层或者最后一层时,更新本层第一存储空间中所述中心节点标识对应的表征向量。
在一个实施例中,上述邻居节点标识对应的第一存储空间中邻居节点的表征向量为:预先写入所述存储空间中各邻居节点表征向量的初始值;或者为所述邻居节点在其他批次的训练数据中作为中心节点时更新的表征向量。
在一个实施例中,表征向量确定模块84,具体用于将所述前一层传递来的所述中心节点的表征向量与获取到的所述各邻居节点的表征向量中各邻居节点的特征向量值进行聚合,得到所述中心节点在本层中的表征向量。
在一个实施例中,上述基于图的卷积网络训练装置,参照图8所示,还可以包括:梯度反向传递模块87;
相应地,所述存储模块81,还用于提供卷积模型除了最低层和最高层之外的每一层中节点标识对应的第二存储空间,每个第二存储空间用于存储所述节点的梯度值;所述第二存储空间是根据预定义的图中的节点个数和预设的卷积模型层数预先建立的;
梯度反向传递模块87,用于将中心节点在最高层的表征向量和所述中心节点在所述图中的标签label输入预设的损失函数,输出所述中心节点在最高层的梯度值;通过所述中心节点在最高层的梯度值,更新所述最高层相邻的前一层的各邻居节点的梯度值和所述中心节点的梯度值;当所述前一层并非最低层时,将更新后的所述前一层的各邻居节点的梯度值写入前一层所述邻居节点标识对应的第二存储空间中;从前一层的中心节点标识对应的第二存储空间中获取所述中心节点的梯度值;根据从第二存储空间中获取到的中心节点的梯度值及更新后的中心节点的梯度值,得到所述中心节点在所述前一层的梯度值;循环执行上述流程,直至得到最低层各邻居节点和中心节点的梯度值。
本发明实施例还提供了一种计算节点装置,包括:处理器、用于存储处理器可执行命令的存储器;其中,处理器被配置为可执行前述基于图的卷积网络训练方法。
本发明实施例还提供了一种基于图的卷积网络训练系统,包括:至少一个前述的计算节点装置和至少两个存储节点装置;
所述计算节点装置,用于从所述至少一个存储节点装置中获取所述图的数据和基于所述图的各批次的训练数据;
所述至少两个存储节点装置,用于分布式地存储所述图,以及提供所述第一存储空间和/或第二存储空间。
本发明实施例还提供了一种非临时性计算机可读存储介质,当所述存储介质中的指令由处理器执行时,能够实现前述基于图的卷积网络训练方法。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器和光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。
Claims (6)
1.一种基于图的卷积网络训练系统,其特征在于,包括:至少一个计算节点装置和至少两个存储节点装置;
所述计算节点装置,包括:处理器、用于存储处理器可执行命令的存储器;其中,处理器被配置为可执行基于图的卷积网络训练方法;用于从所述至少一个存储节点装置中获取所述图的数据和基于所述图的各批次的训练数据;
所述至少两个存储节点装置,用于分布式地存储所述图,以及提供第一存储空间和/或第二存储空间;
所述卷积网络训练方法包括下述步骤:
根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立第一存储空间,每个第一存储空间与一个节点标识对应且用于存储所述节点的特征向量;
基于每个批次的训练数据和所述图,确定所述训练数据的各中心节点的标识及所述中心节点的各邻居节点标识;
针对每个中心节点,从前一层所述中心节点的各邻居节点标识对应的第一存储空间中获取所述各邻居节点的表征向量;
根据前一层传递来的所述中心节点的表征向量及获取到的所述各邻居节点的表征向量,确定所述中心节点在本层中的表征向量,当当前层并非最低层或者最高层时,将确定出的表征向量传递至本层相邻的下一层并更新本层第一存储空间中所述中心节点标识对应的表征向量;
直至得到所述中心节点在最高层的表征向量。
2.如权利要求1所述的系统,其特征在于,邻居节点标识对应的第一存储空间中邻居节点的表征向量为:预先写入所述存储空间中各邻居节点表征向量的初始值;或者为所述邻居节点在其他批次的训练数据中作为中心节点时更新的表征向量。
3.如权利要求1所述的系统,其特征在于,根据前一层传递来的所述中心节点的表征向量及获取到的所述各邻居节点的表征向量,确定所述中心节点在本层中的表征向量,包括:
将所述前一层传递来的所述中心节点的表征向量与获取到的所述各邻居节点的表征向量中各邻居节点的特征向量值进行聚合,得到所述中心节点在本层中的表征向量。
4.如权利要求1-3任一项所述的系统,其特征在于,还包括:
根据预定义的图中的节点个数和预设的卷积模型层数,为卷积模型除了最低层和最高层之外的每一层建立第二存储空间,每个第二存储空间与一个节点标识对应且用于存储所述节点的梯度值;
在得到所述中心节点在最高层的表征向量之后,所述方法还包括:
将中心节点在最高层的表征向量和所述中心节点在所述图中的标签label 输入预设的损失函数,输出所述中心节点在最高层的梯度值;
通过所述中心节点在最高层的梯度值,更新所述最高层相邻的前一层的各邻居节点的梯度值和所述中心节点的梯度值;当所述前一层并非最低层时,将更新后的所述前一层的各邻居节点的梯度值写入前一层所述邻居节点标识对应的第二存储空间中;
从前一层的中心节点标识对应的第二存储空间中获取所述中心节点的梯度值;
根据从第二存储空间中获取到的中心节点的梯度值及更新后的中心节点的梯度值,得到所述中心节点在所述前一层的梯度值;
循环执行上述流程,直至得到最低层各邻居节点和中心节点的梯度值。
5.如权利要求4所述的系统,其特征在于,第二存储空间中所述中心节点的梯度值为:预设的初始梯度值;或者为所述中心节点在其他批次的训练数据中作为邻居节点时更新的梯度值。
6.如权利要求5所述的系统,其特征在于,所述从前一层的中心节点标识对应的第二存储空间中获取所述中心节点的梯度值之后,还包括:
将所述前一层的中心节点标识对应的第二存储空间中的梯度值还原为所述初始梯度值。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910041725.8A CN111445020B (zh) | 2019-01-16 | 2019-01-16 | 一种基于图的卷积网络训练方法、装置及系统 |
PCT/CN2020/070584 WO2020147612A1 (zh) | 2019-01-16 | 2020-01-07 | 一种基于图的卷积网络训练方法、装置及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910041725.8A CN111445020B (zh) | 2019-01-16 | 2019-01-16 | 一种基于图的卷积网络训练方法、装置及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111445020A CN111445020A (zh) | 2020-07-24 |
CN111445020B true CN111445020B (zh) | 2023-05-23 |
Family
ID=71613689
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910041725.8A Active CN111445020B (zh) | 2019-01-16 | 2019-01-16 | 一种基于图的卷积网络训练方法、装置及系统 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN111445020B (zh) |
WO (1) | WO2020147612A1 (zh) |
Families Citing this family (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112070216B (zh) * | 2020-09-29 | 2023-06-02 | 支付宝(杭州)信息技术有限公司 | 一种基于图计算系统训练图神经网络模型的方法及系统 |
CN112035683A (zh) * | 2020-09-30 | 2020-12-04 | 北京百度网讯科技有限公司 | 用户交互信息处理模型生成方法和用户交互信息处理方法 |
CN112347362B (zh) * | 2020-11-16 | 2022-05-03 | 安徽农业大学 | 一种基于图自编码器的个性化推荐方法 |
CN112765373B (zh) * | 2021-01-29 | 2023-03-21 | 北京达佳互联信息技术有限公司 | 资源推荐方法、装置、电子设备和存储介质 |
CN112800186B (zh) * | 2021-04-08 | 2021-10-12 | 北京金山数字娱乐科技有限公司 | 阅读理解模型的训练方法及装置、阅读理解方法及装置 |
CN113343121B (zh) * | 2021-06-02 | 2022-08-09 | 合肥工业大学 | 基于多粒度流行度特征的轻量级图卷积协同过滤推荐方法 |
CN113255844B (zh) * | 2021-07-06 | 2021-12-10 | 中国传媒大学 | 基于图卷积神经网络交互的推荐方法及系统 |
CN113642452B (zh) * | 2021-08-10 | 2023-11-21 | 汇纳科技股份有限公司 | 人体图像质量评价方法、装置、系统及存储介质 |
CN113835899B (zh) * | 2021-11-25 | 2022-02-22 | 支付宝(杭州)信息技术有限公司 | 针对分布式图学习的数据融合方法及装置 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106682734A (zh) * | 2016-12-30 | 2017-05-17 | 中国科学院深圳先进技术研究院 | 一种提升卷积神经网络泛化能力的方法及装置 |
CN108492200A (zh) * | 2018-02-07 | 2018-09-04 | 中国科学院信息工程研究所 | 一种基于卷积神经网络的用户属性推断方法和装置 |
CN108664687A (zh) * | 2018-03-22 | 2018-10-16 | 浙江工业大学 | 一种基于深度学习的工控系统时空数据预测方法 |
CN108776975A (zh) * | 2018-05-29 | 2018-11-09 | 安徽大学 | 一种基于半监督特征和滤波器联合学习的视觉跟踪方法 |
CN109033738A (zh) * | 2018-07-09 | 2018-12-18 | 湖南大学 | 一种基于深度学习的药物活性预测方法 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US7293000B2 (en) * | 2002-02-22 | 2007-11-06 | Lee Shih-Jong J | Information integration method for decision regulation in hierarchic decision systems |
US9799098B2 (en) * | 2007-04-24 | 2017-10-24 | Massachusetts Institute Of Technology | Method and apparatus for image processing |
US10192162B2 (en) * | 2015-05-21 | 2019-01-29 | Google Llc | Vector computation unit in a neural network processor |
CN108229455B (zh) * | 2017-02-23 | 2020-10-16 | 北京市商汤科技开发有限公司 | 物体检测方法、神经网络的训练方法、装置和电子设备 |
CN108648095A (zh) * | 2018-05-10 | 2018-10-12 | 浙江工业大学 | 一种基于图卷积网络梯度的节点信息隐藏方法 |
-
2019
- 2019-01-16 CN CN201910041725.8A patent/CN111445020B/zh active Active
-
2020
- 2020-01-07 WO PCT/CN2020/070584 patent/WO2020147612A1/zh active Application Filing
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106682734A (zh) * | 2016-12-30 | 2017-05-17 | 中国科学院深圳先进技术研究院 | 一种提升卷积神经网络泛化能力的方法及装置 |
CN108492200A (zh) * | 2018-02-07 | 2018-09-04 | 中国科学院信息工程研究所 | 一种基于卷积神经网络的用户属性推断方法和装置 |
CN108664687A (zh) * | 2018-03-22 | 2018-10-16 | 浙江工业大学 | 一种基于深度学习的工控系统时空数据预测方法 |
CN108776975A (zh) * | 2018-05-29 | 2018-11-09 | 安徽大学 | 一种基于半监督特征和滤波器联合学习的视觉跟踪方法 |
CN109033738A (zh) * | 2018-07-09 | 2018-12-18 | 湖南大学 | 一种基于深度学习的药物活性预测方法 |
Non-Patent Citations (3)
Title |
---|
Mathias Niepert等.TOWARDS A SPECTRUM OF GRAPH CONVOLUTIONAL NETWORKS.2018 IEEE Data Science Workshop (DSW).2018,第244-248页. * |
王晓斌 ; 黄金杰 ; 刘文举 ; .基于优化卷积神经网络结构的交通标志识别.计算机应用.2017,(第02期),全文. * |
郎泽宇.基于卷积神经网络的水下目标特征提取方法研究.中国优秀硕士学位论文全文数据库.2018,全文. * |
Also Published As
Publication number | Publication date |
---|---|
CN111445020A (zh) | 2020-07-24 |
WO2020147612A1 (zh) | 2020-07-23 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111445020B (zh) | 一种基于图的卷积网络训练方法、装置及系统 | |
US11227190B1 (en) | Graph neural network training methods and systems | |
JP7322044B2 (ja) | レコメンダシステムのための高効率畳み込みネットワーク | |
CN110263280B (zh) | 一种基于多视图的动态链路预测深度模型及应用 | |
CN109033107B (zh) | 图像检索方法和装置、计算机设备和存储介质 | |
CN111708876B (zh) | 生成信息的方法和装置 | |
CN111444394A (zh) | 获取实体间关系表达的方法、系统和设备、广告召回系统 | |
CN111310074B (zh) | 兴趣点的标签优化方法、装置、电子设备和计算机可读介质 | |
CN102831129B (zh) | 一种基于多示例学习的检索方法及系统 | |
US20220253722A1 (en) | Recommendation system with adaptive thresholds for neighborhood selection | |
US20220138502A1 (en) | Graph neural network training methods and systems | |
CN116010684A (zh) | 物品推荐方法、装置及存储介质 | |
CN115688913A (zh) | 一种云边端协同个性化联邦学习方法、系统、设备及介质 | |
CN111506820A (zh) | 推荐模型、方法、装置、设备及存储介质 | |
CN109754135B (zh) | 信用行为数据处理方法、装置、存储介质和计算机设备 | |
WO2022252694A1 (zh) | 神经网络优化方法及其装置 | |
CN110674181A (zh) | 信息推荐方法、装置、电子设备及计算机可读存储介质 | |
CN115730217A (zh) | 模型的训练方法、物料的召回方法及装置 | |
CN111935259B (zh) | 目标帐号集合的确定方法和装置、存储介质及电子设备 | |
CN114611668A (zh) | 一种基于异质信息网络随机游走的向量表示学习方法及系统 | |
CN114493674A (zh) | 一种广告点击率预测模型及方法 | |
CN112685603A (zh) | 顶级相似性表示的有效检索 | |
CN111091198A (zh) | 一种数据处理方法及装置 | |
CN111292171B (zh) | 金融理财产品推送方法及装置 | |
CN111931058B (zh) | 一种基于自适应网络深度的序列推荐方法和系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |