CN113222143B - 图神经网络训练方法、系统、计算机设备及存储介质 - Google Patents

图神经网络训练方法、系统、计算机设备及存储介质 Download PDF

Info

Publication number
CN113222143B
CN113222143B CN202110602892.2A CN202110602892A CN113222143B CN 113222143 B CN113222143 B CN 113222143B CN 202110602892 A CN202110602892 A CN 202110602892A CN 113222143 B CN113222143 B CN 113222143B
Authority
CN
China
Prior art keywords
gradient
neural network
user
training
local
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110602892.2A
Other languages
English (en)
Other versions
CN113222143A (zh
Inventor
李登昊
王健宗
黄章成
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202110602892.2A priority Critical patent/CN113222143B/zh
Publication of CN113222143A publication Critical patent/CN113222143A/zh
Application granted granted Critical
Publication of CN113222143B publication Critical patent/CN113222143B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/20Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
    • G06F16/24Querying
    • G06F16/245Query processing
    • G06F16/2455Query execution
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60Protecting data
    • G06F21/62Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F21/6218Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
    • G06F21/6245Protecting personal data, e.g. for financial or medical purposes
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • Bioethics (AREA)
  • Databases & Information Systems (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Software Systems (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • Mathematical Physics (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Medical Informatics (AREA)
  • Computer Hardware Design (AREA)
  • Computer Security & Cryptography (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明涉及人工智能技术领域,尤其涉及一种图神经网络训练方法、系统、计算设备及存储介质。该图神经网络训练方法应用在图神经网络训练系统中,包括多个用户端以及一个训练端;图神经网络训练方法包括:训练端初始化图神经网络的网络参数;用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;用户端生成非目标特征的随机伪梯度,并将随机伪梯度与真实梯度作为用户端对应的局部梯度发送至训练端;训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到聚合梯度,并将聚合梯度分发至对应的用户端,以使各用户端根据聚合梯度更新本地的图神经网络。该图神经网络训练方法可有保证用户个人隐私的安全性问题。

Description

图神经网络训练方法、系统、计算机设备及存储介质
技术领域
本发明涉及人工智能技术领域,尤其涉及一种图神经网络训练方法、系统、计算机设备及存储介质。
背景技术
近年来,在推荐系统中图神经网络技术的应用日益广泛。基于图神经网络对推荐系统进行建模方法一般是将用户和推荐项目均作为节点嵌入图中,用户和其感兴趣的项目之间连接。图结构中的每个节点的特征由一特征向量描述,在图神经网络的训练过程中该特征向量被不断优化直至收敛。
然而传统的图神经网络在训练时需要完整的图结构信息,在推荐系统中则需要训练方会收集各用户端的用户偏好信息作为训练样本,而这些信息中包含了用户的个人隐私,容易出现用户隐私泄露的安全性问题。
发明内容
本发明实施例提供一种图神经网络训练方法、系统、计算机设备及存储介质,以解决现有图神经网络训练流程中,无法保证用户个人隐私的安全性问题。
一种图神经网络训练方法,应用在图神经网络训练系统中;所述图神经网络训练系统包括多个用户端以及一个训练端;所述图神经网络训练方法包括如下步骤:
通过所述训练端初始化图神经网络的网络参数;其中,所述网络参数包括用户节点对应的第一特征、全局推荐项目对应的第二特征以及所述用户节点与对应的目标推荐项目之间的连接权值;所述目标推荐项目为所述全局推荐项目中与所述用户节点具有连接关系的推荐项目;所述第二特征包括目标推荐项目对应的目标特征以及非目标推荐项目对应的非目标特征;
通过所述训练端将所述网络参数发送至与所述用户节点对应的用户端,以使所述用户端根据所述网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;其中,所述真实梯度包括所述第一特征、所述目标特征以及所述连接权值对应的真实梯度;
通过所述用户端生成所述非目标特征的随机伪梯度,并将所述随机伪梯度与所述真实梯度作为所述用户端对应的局部梯度发送至所述训练端;
通过所述训练端对接收到的各用户端发送的所述局部梯度进行梯度聚合,得到所述网络参数对应的聚合梯度,并将所述聚合梯度返回至对应的用户端,以使用户端根据所述聚合梯度更新本地的图神经网络。
一种图神经网络训练系统,包括:
初始化模块,用于通过所述训练端初始化图神经网络的网络参数;其中,所述网络参数包括用户节点对应的第一特征、全局推荐项目对应的第二特征以及所述用户节点与对应的目标推荐项目之间的连接权值;所述目标推荐项目为所述全局推荐项目中与所述用户节点具有连接关系的推荐项目;所述第二特征包括目标推荐项目对应的目标特征以及非目标推荐项目对应的非目标特征;
真实梯度计算模块,用于通过所述训练端将所述网络参数发送至与所述用户节点对应的用户端,以使所述用户端根据所述网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;其中,所述真实梯度包括所述第一特征、所述目标特征以及所述连接权值对应的真实梯度;
局部梯度获取与发送模块,用于通过所述用户端生成所述非目标特征的随机伪梯度,并将所述随机伪梯度与所述真实梯度作为所述用户端对应的局部梯度发送至所述训练端;
梯度聚合与分发模块,用于通过所述训练端对接收到的各用户端发送的所述局部梯度进行梯度聚合,得到所述网络参数对应的聚合梯度,并将所述聚合梯度返回至对应的用户端,以使用户端根据所述聚合梯度更新本地的图神经网络。
一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述图神经网络训练方法的步骤。
一种计算机存储介质,所述计算机存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述图神经网络训练方法的步骤。
上述图神经网络训练方法、系统、计算机设备及存储介质中,通过训练端初始化图神经网络的网络参数,并将网络参数分发至对应的用户端,以便用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度,由于真实梯度包括第一特征、目标特征对应的第二梯度以及连接权值等局部参数的梯度,故通过用户端生成非目标特征的随机伪梯度,并将随机伪梯度与真实梯度作为用户端对应的局部梯度发送至训练端,以对真实梯度进行数据混淆,有效保护用户的个人隐私,而用户真实的偏好信息保留在用户端本地,避免了在推荐系统模型中出现隐私泄露的风险。最后,通过训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到聚合梯度,并将聚合梯度分发至对应的用户端,以使各用户端根据聚合梯度更新本地的图神经网络,使得图神经网络优化的梯度是综合各用户端的梯度信息确定,从而在多方数据在不共享的秘密状态下,实现多方参与训练的目的。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明一实施例中图神经网络训练方法的一流程图;
图2是图1中步骤S12的一具体流程图;
图3是图2中步骤S23的一具体流程图;
图4是图2中步骤S24的一具体流程图;
图5是图1中步骤S13的一具体流程图;
图6是图1中步骤S14的一具体流程图;
图7是图2中步骤S22的一具体流程图;
图8是本发明一实施例中图神经网络训练系统的一示意图;
图9是本发明一实施例中计算机设备的一示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明所提供的图神经网络训练方法可应用在一种图神经网络训练系统中,该图神经网络训练系统中包括包括多个用户端以及一个训练端。其中,每一用户端用于训练本地的图神经网络;训练端用于汇总各用户端的梯度信息,并提供更新梯度给各用户端,以使各用户端根据该更新梯度更新本地的图神经网络,从而实现基于联邦学习的图神经网络训练方法,从而可在数据隐私得到保护的前提下训练图神经网络。
在一实施例中,如图1所示,该图神经网络训练方法具体包括如下步骤:
S11:通过训练端初始化图神经网络的网络参数;其中,网络参数包括用户节点对应的第一特征、全局推荐项目对应的第二特征以及用户节点与对应的目标推荐项目之间的连接权值;目标推荐项目为全局推荐项目中与用户节点具有连接关系的推荐项目;第二特征包括目标推荐项目对应的目标特征以及非目标推荐项目对应的非目标特征。
其中,本实施例中以推荐系统建立异构图为例,将用户和全局推荐项目节点嵌入图中,每个节点的特征分别以一个向量表示,节点之间的连接基于用户对项目的喜好构建,即异构图中所有边均建立在用户-目标推荐项目节点之间。针对这一特点,整个异构图结构可以基于用户被拆解为若干本地子图,即每一用户端对应一本地子图。每个子图中仅包含一个用户节点以及相同的多个项目节点(即全局推荐项目)。由于图中不存在连接两个不同用户的边,每个用户端的子图必然可以包含该用户节点的所有邻居节点。
其中,本地子图包括用户节点、全局推荐项目以及全局推荐项目中与用户节点具有连接关系的目标推荐项目,该本地子图中,用户节点由一对应的用户特征向量表示,全局推荐项目以一对应的项目特征向量表示;网络参数包括用户节点的第一特征、全局推荐项目的第二特征以及全局推荐项目中与用户节点对应的目标推荐项目之间的连接权值。
S12:通过训练端将网络参数发送至与用户节点对应的用户端,以使用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;其中,真实梯度包括第一特征、目标特征以及连接权值对应的真实梯度。
其中,通过训练端初始化图神经网络的网络参数,并根据网络参数中的用户节点分发至对应的用户端。
具体地,通过用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度,即基于该网络参数初始化本地的图神经网络,然后将本地子图输入至该本地的图神经网络进行训练,以得到该轮训练中局部网络参数对应的真实梯度。该局部网络参数包括第一特征、目标特征以及连接权值。
可以理解地是,在计算各网络参数的梯度时,非推荐项目的梯度未被计算,仅计算局部网络参数的真实梯度,即反映用户喜好信息的目标推荐项目对应的真实梯度、连接权值对应的真实梯度以及用户节点对应的真实梯度。
S13:通过用户端生成非目标特征的随机伪梯度,并将随机伪梯度与真实梯度作为用户端对应的局部梯度发送至训练端。
其中,局部梯度表示用户端计算的各网络参数对应的梯度信息。该局部梯度包括用户节点对应第一特征的真实梯度、全局推荐项目的第二特征的真实梯度以及连接权值对应的真实梯度。该全局推荐项目的第二特征的梯度包括目标推荐项目对应的目标特征的真实梯度以及非目标推荐项目对应的非目标特征的随机伪梯度。
具体地,由于在计算时仅有与用户相关的目标推荐项目的特征向量计算了真实梯度,因此梯度信息中包含了用户感兴趣的目标推荐项目信息,故本实施例中通过在用户端生成随机伪梯度,并将该随机伪梯度作为非目标特征对应的梯度信息,将该随机伪梯度与真实梯度作为该用户端对应的局部梯度发送至训练端,以对真实梯度进行数据混淆,有效保护用户的个人隐私,而用户真实的偏好信息保留在用户端本地,避免了在推荐系统模型中出现隐私泄露的风险。
S14:通过训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到网络参数对应的聚合梯度,并将聚合梯度返回至对应的用户端,以使用户端根据聚合梯度更新本地的图神经网络。
本实施例中,训练端对接收到的各用户端发送的局部梯度进行梯度聚合,可包括但不限计算各局部梯度的累加和或对各用户端发送的局部梯度取平均值处理,以得到对应的聚合梯度。
具体地,通过训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到聚合梯度,并将聚合梯度分发至各用户端,以使各用户端根据聚合梯度更新本地的图神经网络,使得图神经网络优化的梯度是综合各用户端的梯度信息确定,从而在多方数据在不共享的秘密状态下,实现多方参与训练的目的。
本实施例中,通过训练端初始化图神经网络的网络参数,并将网络参数分发至对应的用户端,以便用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度,由于真实梯度包括第一特征、目标特征对应的第二梯度以及连接权值等局部参数的梯度,故通过用户端生成非目标特征的随机伪梯度,并将随机伪梯度与真实梯度作为用户端对应的局部梯度发送至训练端,以对真实梯度进行数据混淆,有效保护用户的个人隐私,而用户真实的偏好信息保留在用户端本地,避免了在推荐系统模型中出现隐私泄露的风险。最后,通过训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到聚合梯度,并将聚合梯度分发至对应的用户端,以使各用户端根据聚合梯度更新本地的图神经网络,使得图神经网络优化的梯度是综合各用户端的梯度信息确定,从而在多方数据在不共享的秘密状态下,实现多方参与训练的目的。
在一实施例中,如图2所示,目标推荐项对应一真实标注值;即步骤S12中,即通过用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度,具体包括如下步骤:
S21:基于网络参数初始化图神经网络。
具体地,根据该网络参数初始化本地的图神经网络,以保证各用户端对应的本地图神经网络基于同一初始化的网络参数进行训练,以保证多方训练的有效性和准确性。
S22:将本地子图输入至图神经网络进行预测,得到图神经网络输出的预测结果。
其中,将本地子图输入至图神经网络进行预测,得到图神经网络输出的预测结果,即将本地子图作为训练样本输入至图神经网络中进行特征的聚合处理,以得到图中每一特征节点对应的聚合后的特征向量,并将聚合后的用户节点对应的聚合特征与其具有连接关系的各目标推荐项目节点对应的特征做内积处理,计算相似度以作为不同目标推荐项目的推荐概率,即预测结果。该聚合处理即为将某特征节点与其具有连接关系的其他特征节点的特征通过连接权值进行聚合。
S23:根据预测结果以及真实标注值,计算网络损失。
其中,真实标注值可指用户节点与不同目标推荐项目间的评分(例如70分、80分)或分类评价(例如喜欢或不喜欢)。具体地,可根据当前图神经网络的预测任务的不同,采用不同的损失函数计算损失,若图神经网络的预测任务为分类任务,则调用交叉熵损失函数,根据预测结果以及真实标注值,计算网络损失;若图神经网络的预测任务为回归任务,则调用均方误差损失函数,根据预测结果以及真实标注值,计算网络损失。
S24:根据网络损失,计算真实梯度。
具体地,通过该网络损失以对不同的局部网络参数求偏导,即可获取局部网络参数对应的真实梯度。其中,局部网络参数包括用户节点对应的第一特征、目标推荐项目对应的目标特征以及连接权值。
在一实施例中,如图3所示,步骤S23中,即根据预测结果以及真实标注值,计算网络损失,具体包括如下步骤:
S31:若图神经网络的预测任务为分类任务,则调用交叉熵损失函数,根据预测结果以及真实标注值,计算网络损失。
具体地,若图神经网络的预测任务为分类任务,则调用交叉熵损失函数,根据预测结果以及真实标注值,计算网络损失,即通过如下公式计算网络损失,crossEntropy=∑k-yk*log(pk),其中,pk是预测用户对目标推荐项目评价为第k类的概率,即预测结果,当真实标注值,即用户实际对目标推荐项目的评价是第i类,那么yi=1,对于不等于i的k,yk均为0。
S32:若图神经网络的预测任务为回归任务,则调用均方误差损失函数,根据预测结果以及真实标注值,计算网络损失。
具体地,若图神经网络的预测任务为回归任务,则调用均方误差损失函数,根据预测结果以及真实标注值,计算网络损失,即通过如下公式计算网络损失,mean squarederror=∑(y-groundtruth)2,其中y为模型预测结果,例如网络预测的用户对各目标推荐项目的评价分数(例如70分、80分或90分),groundtruth即为真实标注值,用于表示用户实际对目标推荐项目的评价分数。
在一实施例中,如图4所示,步骤S24中,即根据网络损失,计算真实梯度,具体如下步骤:
S41:根据网络损失以及第一特征,计算第一特征对应的真实梯度。
S42:根据网络损失以及连接权值,计算连接权值对应的真实梯度。
S42:根据网络损失以及目标特征,计算目标特征对应的真实梯度。
具体地,通过基于网络损失分别对第一特征、连接权值以及目标特征计算偏导,即可获取真实梯度。以公式表示为:其中,l表示网络损失,θ表示第一特征、连接权值或目标特征。
在一实施例中,如图5所示,步骤S13中,即通过用户端生成非目标特征的随机伪梯度,具体包括如下步骤:
S51:根据随机函数,生成一组零均值的随机数。
S52:将零均值的随机数作为非目标特征对应的随机伪梯度。
具体地,由于后续在训练端对各局部梯度进行聚合操作,以综合多方梯度信息训练模型,故为保证随机伪梯度不影响后续梯度聚合的聚合结果,本实施例红中通过生成一组零均值的随机数,并将该零均值的随机数作为非目标特征对应的随机伪梯度,避免随机伪梯度对后续计算聚合梯度的影响。
在一实施例中,如图6所示,步骤S14中,即通过训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到聚合梯度,具体包括如下步骤:
S61:对各用户端发送的局部梯度进行累加处理,以将得到的累加结果作为聚合梯度。
S62:对各用户端发送的局部梯度进行取平均处理,以将得到的平均值作为聚合梯度。
本实施例中,训练端对接收到的各用户端发送的局部梯度进行梯度聚合,可包括但不限计算各局部梯度的累加和或对各用户端发送的局部梯度取平均值处理,以得到对应的聚合梯度。其中,该聚合梯度包括用户节点对应的第一特征的聚合梯度、目标推荐项目对应的目标特征的聚合梯度以及连接权值的聚合梯度。
在一实施例中,如图7所示,本地子图包括用户节点以及与用户节点具有连接关系的目标推荐项目;步骤S22中,即将本地子图输入至图神经网络进行预测,得到图神经网络输出的预测结果,具体包括如下步骤:
S71:基于连接权值,对用户节点以及目标推荐项目进行聚合更新,得到用户节点对应的第一聚合特征以及目标推荐项目对应的第二聚合特征。
S72:计算第一聚合特征以及第二聚合特征间的相似度,得到图神经网络输出的预测结果。
可以理解地是,图神经网络中的主要思想为通过对某节点具有连接关系的多个邻居节点的特征进行聚合,以更新该节点自身的信息表达,而不再仅针对自身特征进行网络训练,可解释性更强。
具体地,将本地子图输入至图神经网络进行预测,得到图神经网络输出的预测结果,即将本地子图作为训练样本输入至图神经网络中进行特征的聚合处理,以得到图中每一特征节点(即用户节点以及目标推荐项目)对应的聚合后的特征向量。
示例性地,假设本地子图中包括用户节点A,与其具有连接关系的目标推荐项目包括B、C和D,用户节点A与每一目标推荐项目之间的连接权值分别为a、b、c,此时A对应一特征向量比如(1,1,1),B对应一特征向量比如(2,2,2)、C对应一特征向量比如(3,3,3)以及D对应一特征向量比如(4,4,4),第一聚合特征即为该用户节点与每一目标推荐项目通过连接权值聚合,即Z=A+α*N,N=a*B+b*C+c*D=a*(2,2,2)+b*(3,3,3)+c*(4,4,4),α为一可调节参数,该参数的取值可自定义或通过注意力机制等方式获取,此处不做限定。
需要说明的是上述示例仅作举例说明,对于图神经网络模型中的具体计算也可根据实际情况进行调整,此处不做限定。
进一步地,将聚合后的用户节点对应的第一聚合特征与其具有连接关系的各目标推荐项目节点对应的第二聚合特征做内积处理,即A(第一聚合特征对应的向量)·B(第一聚合特征对应的向量),以计算余弦相似度以通过该余弦相似度衡量这两个特征向量间的相似性,作为不同目标推荐项目的推荐概率,即预测结果。该聚合处理即为将某特征节点与其具有连接关系的其他特征节点的特征通过连接权值进行聚合。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
在一实施例中,提供一种图神经网络训练系统,该图神经网络训练系统与上述实施例中图神经网络训练方法一一对应。如图8所示,该图神经网络训练系统包括初始化模块10、真实梯度计算模块20、局部梯度获取与发送模块30以及梯度聚合与分发模块40。各功能模块详细说明如下:
初始化模块10,用于通过训练端初始化图神经网络的网络参数;其中,网络参数包括用户节点对应的第一特征、全局推荐项目对应的第二特征以及用户节点与对应的目标推荐项目之间的连接权值;目标推荐项目为全局推荐项目中与用户节点具有连接关系的推荐项目;第二特征包括目标推荐项目对应的目标特征以及非目标推荐项目对应的非目标特征。
真实梯度计算模块20,用于通过训练端将网络参数发送至与用户节点对应的用户端,以使用户端根据网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;其中,真实梯度包括第一特征、目标特征以及连接权值对应的真实梯度。
局部梯度获取与发送模块30,用于通过用户端生成非目标特征的随机伪梯度,并将随机伪梯度与真实梯度作为用户端对应的局部梯度发送至训练端。
梯度聚合与分发模块40,用于通过训练端对接收到的各用户端发送的局部梯度进行梯度聚合,得到网络参数对应的聚合梯度,并将聚合梯度返回至对应的用户端,以使用户端根据聚合梯度更新本地的图神经网络。
具体地,局部梯度获取与发送模块包括网络初始化单元、预测单元、网络损失计算单元以及梯度计算单元。
网络初始化单元,用于基于网络参数初始化图神经网络。
预测单元,用于将本地子图输入至图神经网络进行预测,得到图神经网络输出的预测结果。
网络损失计算单元,用于根据预测结果以及真实标注值,计算网络损失。
梯度计算单元,用于根据网络损失,计算真实梯度。
具体地,网络损失计算单元包括第一计算单元和第二计算单元。
第一计算单元,用于若图神经网络的预测任务为分类任务,则调用交叉熵损失函数,根据预测结果以及真实标注值,计算网络损失。
第二计算单元,用于若图神经网络的预测任务为回归任务,则调用均方误差损失函数,根据预测结果以及真实标注值,计算网络损失。
具体地,梯度计算单元包括第一梯度计算单元、第二梯度计算单元以及第三梯度计算单元。
第一梯度计算单元,用于根据网络损失以及第一特征,计算第一特征对应的真实梯度。
第二梯度计算单元,用于根据网络损失以及连接权值,计算连接权值对应的真实梯度。
第三梯度计算单元,用于根据网络损失以及目标特征,计算目标特征对应的真实梯度。
具体地,局部梯度获取与发送模块包括随机生成单元和随机伪梯度获取单元。
随机生成单元,用于根据随机函数,生成一组零均值的随机数。
随机伪梯度获取单元,用于将零均值的随机数作为非目标特征对应的随机伪梯度。
具体地,梯度聚合与分发模块包括第一聚合单元以及第二聚合单元。
第一聚合单元,用于对各用户端发送的局部梯度进行累加处理,以将得到的累加结果作为聚合梯度;或者
第二聚合单元,用于对各用户端发送的局部梯度进行取平均处理,以将得到的平均值作为聚合梯度。
具体地,预测单元包括聚合更新子单元和预测结果计算子单元。
聚合更新子单元,用于基于连接权值,对用户节点以及目标推荐项目进行聚合更新,得到述用户节点对应的第一聚合特征以及目标推荐项目对应的第二聚合特征。
预测结果计算子单元,用于计算第一聚合特征以及第二聚合特征间的相似度,得到图神经网络输出的预测结果。
关于图神经网络训练系统的具体限定可以参见上文中对于图神经网络训练方法的限定,在此不再赘述。上述图神经网络训练系统中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图9所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括计算机存储介质、内存储器。该计算机存储介质存储有操作系统、计算机程序和数据库。该内存储器为计算机存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储执行图神经网络训练方法过程中生成或获取的数据,如候选查询数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种图神经网络训练方法。
在一个实施例中,提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述实施例中的图神经网络训练方法的步骤,例如图1所示的步骤S11-S14,或者图2至图7中所示的步骤。或者,处理器执行计算机程序时实现图神经网络训练系统这一实施例中的各模块/单元的功能,例如图8所示的各模块/单元的功能,为避免重复,这里不再赘述。
在一实施例中,提供一计算机存储介质,该计算机存储介质上存储有计算机程序,该计算机程序被处理器执行时实现上述实施例中图神经网络训练方法的步骤,例如图1所示的步骤S11-S14,或者图2至图7中所示的步骤,为避免重复,这里不再赘述。或者,该计算机程序被处理器执行时实现上述图神经网络训练系统这一实施例中的各模块/单元的功能,例如图8所示的各模块/单元的功能,为避免重复,这里不再赘述。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。
以上实施例仅用以说明本发明的技术方案,而非对其限制,尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。

Claims (10)

1.一种图神经网络训练方法,其特征在于,应用在图神经网络训练系统中,所述图神经网络训练系统包括多个用户端以及一个训练端;所述图神经网络训练方法包括:
通过所述训练端初始化图神经网络的网络参数;其中,所述网络参数包括用户节点对应的第一特征、全局推荐项目对应的第二特征以及所述用户节点与对应的目标推荐项目之间的连接权值;所述目标推荐项目为所述全局推荐项目中与所述用户节点具有连接关系的推荐项目;所述第二特征包括目标推荐项目对应的目标特征以及非目标推荐项目对应的非目标特征;
通过所述训练端将所述网络参数发送至与所述用户节点对应的用户端,以使所述用户端根据所述网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;其中,所述真实梯度包括所述第一特征、所述目标特征以及所述连接权值对应的真实梯度;
通过所述用户端生成所述非目标特征的随机伪梯度,并将所述随机伪梯度与所述真实梯度作为所述用户端对应的局部梯度发送至所述训练端;
通过所述训练端对接收到的各用户端发送的所述局部梯度进行梯度聚合,得到所述网络参数对应的聚合梯度,并将所述聚合梯度返回至对应的用户端,以使用户端根据所述聚合梯度更新本地的图神经网络。
2.如权利要求1所述图神经网络训练方法,其特征在于,所述目标推荐项对应一真实标注值;所述通过所述用户端根据所述网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度,包括:
基于所述网络参数初始化所述图神经网络;
将所述本地子图输入至所述图神经网络进行预测,得到所述图神经网络输出的预测结果;
根据所述预测结果以及所述真实标注值,计算网络损失;
根据所述网络损失,计算所述真实梯度。
3.如权利要求2所述图神经网络训练方法,其特征在于,所述根据所述预测结果以及所述真实标注值,计算网络损失,包括:
若所述图神经网络的预测任务为分类任务,则调用交叉熵损失函数,根据所述预测结果以及所述真实标注值,计算网络损失;
若所述图神经网络的预测任务为回归任务,则调用均方误差损失函数,根据所述预测结果以及所述真实标注值,计算网络损失。
4.如权利要求3所述图神经网络训练方法,其特征在于,所述根据所述网络损失,计算所述真实梯度,包括:
根据所述网络损失以及所述第一特征,计算所述第一特征对应的真实梯度;
根据所述网络损失以及所述连接权值,计算所述连接权值对应的真实梯度;
根据所述网络损失以及所述目标特征,计算所述目标特征对应的真实梯度。
5.如权利要求4所述图神经网络训练方法,其特征在于,所述通过所述用户端生成所述非目标特征的随机伪梯度,包括:
根据随机函数,生成一组零均值的随机数;
将所述零均值的随机数作为所述非目标特征对应的随机伪梯度。
6.如权利要求5所述图神经网络训练方法,其特征在于,所述通过所述训练端对接收到的各用户端发送的所述局部梯度进行梯度聚合,得到聚合梯度,包括:
对所述各用户端发送的局部梯度进行累加处理,以将得到的累加结果作为所述聚合梯度;或者
对所述各用户端发送的局部梯度进行取平均处理,以将得到的平均值作为所述聚合梯度。
7.如权利要求2所述图神经网络训练方法,其特征在于,所述本地子图包括用户节点以及与所述用户节点具有连接关系的目标推荐项目;所述将所述本地子图输入至所述图神经网络进行预测,得到所述图神经网络输出的预测结果,包括:
基于所述连接权值,对所述用户节点以及所述目标推荐项目进行聚合更新,得到所述用户节点对应的第一聚合特征以及所述目标推荐项目对应的第二聚合特征;
计算所述第一聚合特征以及所述第二聚合特征间的相似度,得到所述图神经网络输出的预测结果。
8.一种图神经网络训练系统,其特征在于,包括:
初始化模块,用于通过训练端初始化图神经网络的网络参数;其中,所述网络参数包括用户节点对应的第一特征、全局推荐项目对应的第二特征以及所述用户节点与对应的目标推荐项目之间的连接权值;所述目标推荐项目为所述全局推荐项目中与所述用户节点具有连接关系的推荐项目;所述第二特征包括目标推荐项目对应的目标特征以及非目标推荐项目对应的非目标特征;
真实梯度计算模块,用于通过所述训练端将所述网络参数发送至与所述用户节点对应的用户端,以使所述用户端根据所述网络参数和预先构建的本地子图训练本地的图神经网络,得到真实梯度;其中,所述真实梯度包括所述第一特征、所述目标特征以及所述连接权值对应的真实梯度;
局部梯度获取与发送模块,用于通过所述用户端生成所述非目标特征的随机伪梯度,并将所述随机伪梯度与所述真实梯度作为所述用户端对应的局部梯度发送至所述训练端;
梯度聚合与分发模块,用于通过所述训练端对接收到的各用户端发送的所述局部梯度进行梯度聚合,得到所述网络参数对应的聚合梯度,并将所述聚合梯度返回至对应的用户端,以使用户端根据所述聚合梯度更新本地的图神经网络。
9.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述图神经网络训练方法的步骤。
10.一种计算机存储介质,所述计算机存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述图神经网络训练方法的步骤。
CN202110602892.2A 2021-05-31 2021-05-31 图神经网络训练方法、系统、计算机设备及存储介质 Active CN113222143B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110602892.2A CN113222143B (zh) 2021-05-31 2021-05-31 图神经网络训练方法、系统、计算机设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110602892.2A CN113222143B (zh) 2021-05-31 2021-05-31 图神经网络训练方法、系统、计算机设备及存储介质

Publications (2)

Publication Number Publication Date
CN113222143A CN113222143A (zh) 2021-08-06
CN113222143B true CN113222143B (zh) 2023-08-01

Family

ID=77081780

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110602892.2A Active CN113222143B (zh) 2021-05-31 2021-05-31 图神经网络训练方法、系统、计算机设备及存储介质

Country Status (1)

Country Link
CN (1) CN113222143B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114491590A (zh) * 2022-01-17 2022-05-13 平安科技(深圳)有限公司 基于联邦因子分解机的同态加密方法、系统、设备及存储介质
CN114491629B (zh) * 2022-01-25 2024-06-18 哈尔滨工业大学(深圳) 一种隐私保护的图神经网络训练方法及系统
CN114462600B (zh) * 2022-04-11 2022-07-05 支付宝(杭州)信息技术有限公司 一种有向图对应的图神经网络的训练方法及装置
CN117273086B (zh) * 2023-11-17 2024-03-08 支付宝(杭州)信息技术有限公司 多方联合训练图神经网络的方法及装置

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110929870A (zh) * 2020-02-17 2020-03-27 支付宝(杭州)信息技术有限公司 图神经网络模型训练方法、装置及系统
CN111985622A (zh) * 2020-08-25 2020-11-24 支付宝(杭州)信息技术有限公司 一种图神经网络训练方法和系统
WO2021082681A1 (zh) * 2019-10-29 2021-05-06 支付宝(杭州)信息技术有限公司 多方联合训练图神经网络的方法及装置

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021082681A1 (zh) * 2019-10-29 2021-05-06 支付宝(杭州)信息技术有限公司 多方联合训练图神经网络的方法及装置
CN110929870A (zh) * 2020-02-17 2020-03-27 支付宝(杭州)信息技术有限公司 图神经网络模型训练方法、装置及系统
CN111985622A (zh) * 2020-08-25 2020-11-24 支付宝(杭州)信息技术有限公司 一种图神经网络训练方法和系统

Also Published As

Publication number Publication date
CN113222143A (zh) 2021-08-06

Similar Documents

Publication Publication Date Title
CN113222143B (zh) 图神经网络训练方法、系统、计算机设备及存储介质
Rodríguez-Barroso et al. Federated Learning and Differential Privacy: Software tools analysis, the Sherpa. ai FL framework and methodological guidelines for preserving data privacy
Jia et al. Residual correlation in graph neural network regression
Saxena et al. D-GAN: Deep generative adversarial nets for spatio-temporal prediction
CN113221183B (zh) 实现隐私保护的多方协同更新模型的方法、装置及系统
CN110929047A (zh) 关注邻居实体的知识图谱推理方法和装置
CN111931076B (zh) 基于有权有向图进行关系推荐的方法、装置和计算机设备
CA3151805A1 (en) Automated path-based recommendation for risk mitigation
CN110210233B (zh) 预测模型的联合构建方法、装置、存储介质及计算机设备
Yoon et al. Robustifying sequential neural processes
CN113240505B (zh) 图数据的处理方法、装置、设备、存储介质及程序产品
JP7361928B2 (ja) 勾配ブースティングを介したプライバシーを守る機械学習
Livieris et al. A new conjugate gradient algorithm for training neural networks based on a modified secant equation
CN113822315A (zh) 属性图的处理方法、装置、电子设备及可读存储介质
US20220027793A1 (en) Dedicated artificial intelligence system
Varma et al. Legato: A layerwise gradient aggregation algorithm for mitigating byzantine attacks in federated learning
Tang et al. A factorization machine-based QoS prediction approach for mobile service selection
KR20220145380A (ko) 컨텐츠 배포 및 분석을 위한 프라이버시 보호 기계 학습
CN113761367A (zh) 机器人流程自动化程序的推送系统及方法、装置、计算设备
Faezi et al. Degan: Decentralized generative adversarial networks
Sun et al. Communication-efficient vertical federated learning with limited overlapping samples
CN113705797A (zh) 基于图神经网络的推荐模型训练方法、装置、设备及介质
CN113792892A (zh) 联邦学习建模优化方法、设备、可读存储介质及程序产品
Kumar et al. FIDEL: Fog integrated federated learning framework to train neural networks
CN113254996B (zh) 图神经网络训练方法、装置、计算设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant