CN115965079A - 图模型训练方法和装置 - Google Patents
图模型训练方法和装置 Download PDFInfo
- Publication number
- CN115965079A CN115965079A CN202211680435.6A CN202211680435A CN115965079A CN 115965079 A CN115965079 A CN 115965079A CN 202211680435 A CN202211680435 A CN 202211680435A CN 115965079 A CN115965079 A CN 115965079A
- Authority
- CN
- China
- Prior art keywords
- graph
- graph model
- trained
- training sample
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本说明书实施例提供了一种图模型训练方法和装置。待训练的图模型适用于第一业务场景中;该方法包括:得到教师网络图模型;其中,所述教师网络图模型适用于第二业务场景中且为训练完毕的图模型;从第一业务场景中得到训练样本;利用教师网络图模型与待训练的图模型分别学习训练样本;对根据教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失;利用待训练的图模型对训练样本的学习结果,得到业务损失;根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。本说明书实施例能够减少对第一业务场景中训练样本的数量的要求。
Description
技术领域
本说明书一个或多个实施例涉及人工智能技术,尤其涉及图模型训练方法和装置。
背景技术
图(Graph)是用于表示对象之间关联关系的一种抽象数据结构,使用节点(Vertex)和边(Edge)进行描述,其中,节点表示对象,边表示对象之间的关系。图中的每一个节点都具有自己的各种特征,每一个边也具有自己的各种特征。例如,参见图1所示,图中的节点可以为用户的账户信息,边可以为各个用户之间的交易行为,那么,每一个节点包括的特征可以为涉及到账户的所有特征,比如账户ID、人群、相关用户的性别、年龄、学历、账户信息、资产信息、历史交易习惯等各种信息,而每一个边包括的特征可以为涉及到一个交易的所有特征,比如交易ID、交易发生的时间、交易发生的地点、金额、支付渠道、交易的性质比如是否属于违规交易等。
目前,在很多业务场景下,例如资金交易风控系统,商品推荐系统,都需要基于对象之间的关系来进行分析与处理,而利用神经网络的图模型作为专门针对这类非结构化数据的处理方法,在近些年来也是被广泛使用。然而作为机器模型,图模型在训练阶段,需要依赖数量巨大的样本,通过海量的样本才能训练出性能良好的图模型。
然而在很多业务场景中,比如新兴领域,样本的数量是有限的,无法得到满足训练要求的数量的样本。进一步地,如果采用海量训练样本,那么在构图关联和特征清洗中,往往占据了整个图模型上线流程中的大部分时间,消耗了工程技术人员的绝大部分精力。
因此,需要一种新的图模型训练方法,从而不再依赖海量的训练样本。
发明内容
本说明书一个或多个实施例描述了图模型的训练方法和装置,能够减少对训练样本的数量的要求。
根据第一方面,提供了一种图模型训练方法,其中,该待训练的图模型适用于第一业务场景中;该方法包括:
得到教师网络图模型;其中,所述教师网络图模型适用于第二业务场景中且为训练完毕的图模型;
从第一业务场景中得到训练样本;
利用教师网络图模型与待训练的图模型分别学习训练样本;
针对教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失;
利用待训练的图模型对训练样本的学习结果,以得到业务损失;
根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。
其中,所述利用教师网络图模型与待训练的图模型分别学习训练样本,包括:
利用教师网络图模型从训练样本中提取图结构,以得到第一图结构表征;
利用待训练的图模型从训练样本中提取图结构,以得到第二图结构表征;
利用教师网络图模型从训练样本中提取图特征,以得到第一图特征表征;
利用待训练的图模型从训练样本中提取图特征,以得到第二图特征表征;
对应地,所述针对根据教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束得到差异损失,包括:
对第一图结构表征与第二图结构表征进行相似性约束,以得到第一差异损失;
对第一图特征表征与第二图特征表征进行相似性约束,以得到第二差异损失。
其中,该方法进一步包括:
根据所述训练样本得到邻居关系矩阵;其中,所述邻居关系矩阵为N*N的矩阵,N为根据训练样本得到的节点的数量;对于任意两个节点,如果该两个节点直连,形成一阶邻居关系,则在所述邻居关系矩阵中对应于该两个节点的矩阵元素的值为1,否则为0;
根据所述训练样本得到特征矩阵;其中,所述特征矩阵为N*M的矩阵,M为根据训练样本得到的每一个节点包括的特征的数量;所述特征矩阵中的每一行对应一个节点,该行中不同的矩阵元素表示该节点的不同的特征。
其中,该方法进一步包括:
将所述特征矩阵中的所有矩阵元素的值均设置为0,以得到第一特征矩阵;
所述利用教师网络图模型从训练样本中提取图结构,包括:
将所述邻居关系矩阵、所述第一特征矩阵输入所述教师网络图模型,以得到该教师网络图模型输出的第一图结构表征;
所述利用待训练的图模型从训练样本中提取图结构,包括:
将所述邻居关系矩阵、所述第一特征矩阵输入所述待训练的图模型,以得到该待训练的图模型输出的第二图结构表征。
其中,该方法进一步包括:
针对所述特征矩阵中的每一个矩阵元素,将该矩阵元素的值设置为该矩阵元素所对应节点的对应特征的值,以得到第二特征矩阵;
所述利用教师网络图模型从训练样本中提取图特征,包括:
将所述邻居关系矩阵、所述第二特征矩阵输入所述教师网络图模型,以得到该教师网络图模型输出的第一图特征表征;
所述利用待训练的图模型从训练样本中提取图特征,包括:
将所述邻居关系矩阵、所述第二特征矩阵输入所述待训练的图模型,以得到该待训练的图模型输出的第二图特征表征。
其中,所述对第一图结构表征与第二图结构表征进行相似性约束得到第一差异损失,包括:
利用计算均方误差的方法对第一图结构表征与第二图结构表征求相似,以得到所述第一差异损失;
所述对第一图特征表征与第二图特征表征进行相似性约束得到第二差异损失,包括:
利用计算均方误差的方法对第一图特征表征与第二图特征表征求相似,以得到所述第二差异损失。
其中,所述利用待训练的图模型对训练样本的学习结果得到业务损失,包括:
使用多层感知机对所述第二图特征表征进行推理,从而将所述第二图特征表征转换为0至1之间的预测数值;
利用所述预测数值与训练样本包括的标签信息计算二分类损失,从而得到所述业务损失。
其中,所述根据所述差异损失以及所述业务损失调整所述待训练的图模型的参数,包括:
设置所述待训练的图模型的约束函数为:L=loss0+α·loss1+β·loss2;
其中,L表示约束函数;loss0表示所述业务损失;loss1表示所述第一差异损失;loss2表示所述第二差异损失;α、β均为预先设置的超参约束,α、β的值越大,则借鉴教师网络图模型的知识的权重比例越大。
根据第二方面,提供了一种图模型训练装置,其中,该待训练的图模型适用于第一业务场景中;该装置包括:
教师网络获取模块,配置为得到教师网络图模型;其中,所述教师网络图模型适用于第二业务场景中且为训练完毕的图模型;
训练样本得到模块,配置为从第一业务场景中得到训练样本;
学习模块,配置为利用教师网络图模型与待训练的图模型分别学习训练样本;
差异损失得到模块,配置为对根据教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失;
业务损失得到模块,配置为利用待训练的图模型对训练样本的学习结果,得到业务损失;
模型参数调整模块,配置为根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。
根据第三方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现本说明书任一实施例所述的方法。
本说明书实施例提供的图模型的训练方法及装置,当需要训练适用于第一业务场景中的图模型时,会使用在其他业务场景中已经训练出的图模型(称为教师网络图模型)来对第一业务场景中的训练样本进行学习,因为对教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,从而得到差异损失,因此,可以实现将教师网络图模型根据训练样本学习到的知识蒸馏到适用于第一业务场景的待训练的图模型中。即,对新场景的图模型进行一定的知识蒸馏,这样不仅可以降低对第一业务场景中训练样本的数量要求,即降低所需要准备的图数据的训练样本的量级,而且可以使用到相关联的知识,提升待训练的图模型的性能。因此,本说明书实施例实际上提出了一种针对图模型的知识蒸馏方案,利用已有的历史图模型,降低训练样本的量级,提升待训练的图模型的性能。
附图说明
为了更清楚地说明本说明书实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本说明书的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是图网络的示意图。
图2是本说明书一个实施例中图模型的训练方法的流程图。
图3是本说明书一个实施例中利用教师网络图模型与待训练的图模型分别学习训练样本并进行相似性约束的方法流程图。
图4是本说明书一个实施例中邻居关系矩阵的一种示意图。
图5是本说明书一个实施例中特征矩阵的一种示意图。
图6是本说明书一个实施例中图模型的训练方法的示意图。
图7是本说明书一个实施例中图模型的训练装置的结构示意图。
图8是本说明书另一个实施例中图模型的训练装置的结构示意图。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
首先需要说明的是,在本发明实施例中使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本发明。在本发明实施例和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。
应当理解,本文中使用的术语“和/或”仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。另外,本文中字符“/”,一般表示前后关联对象是一种“或”的关系。
图2是本说明书一个实施例中图模型的训练方法的流程图。该方法的执行主体为图模型的训练装置。可以理解,该方法也可以通过任何具有计算、处理能力的装置、设备、平台、设备集群来执行。参见图2,需要训练适用于第一业务场景的图模型;该方法包括:
步骤201:得到教师网络图模型;其中,教师网络图模型适用于第二业务场景中且为训练完毕的图模型。
步骤203:从第一业务场景中得到训练样本。
步骤205:利用教师网络图模型与待训练的图模型分别学习训练样本。
步骤207:针对教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,得到差异损失。
步骤209:利用待训练的图模型对训练样本的学习结果,得到业务损失。
步骤211:根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。
根据上述图2所示的流程可以看出,在本说明书实施例中,当需要训练适用于第一业务场景中的图模型时,会使用在其他业务场景中已经训练出的图模型(称为教师网络图模型)来对第一业务场景中的训练样本进行学习,因为对教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,从而得到差异损失,因此,可以实现将教师网络图模型根据训练样本学习到的知识蒸馏到适用于第一业务场景的待训练的图模型中。即,对新场景的图模型进行一定的知识蒸馏,这样不仅可以降低对第一业务场景中训练样本的数量要求,即降低所需要准备的图数据训练样本的量级,而且可以使用到相关联的知识,提升待训练的图模型的性能。因此,本说明书实施例实际上提出了一种针对图模型的知识蒸馏方案,利用已有的历史图模型,降低训练样本的量级,提升待训练的图模型的性能。
下面结合附图及具体的例子对上述图2所示的流程进行说明。
首先对于步骤201:得到教师网络图模型;其中,教师网络图模型适用于第二业务场景中且为训练完毕的图模型。
在本说明书实施例中,当需要在一个业务场景记为第一业务场景中训练出图模型时,可能会出现该第一业务场景中训练样本不足等问题。比如,对于一个新兴的业务场景,由于历史业务数据不足,因此能够用于训练图模型的训练样本的数量也非常有限,因此,无法训练出适用于该新兴业务场景的性能优异的图模型。
而对于其他业务场景,比如任意一个第二业务场景,因为其历史业务数据充足,已经使用足够的训练样本训练出了适用于该第二业务场景的图模型。虽然适用于该第二业务场景的图模型无法直接适用于第一业务场景中,但是因为都是图模型,因此,需要学习、训练的方式是相同的,都是需要学习到图数据中各个对象所形成的关系(该关系可以称为图结构)以及各个对象的特征。因此,本说明书实施例中,考虑使用知识蒸馏的方式,将适用于第二业务场景的教师网络图模型学习到的第一业务场景中的训练样本的特征,以知识蒸馏的方式蒸馏到适用于第一业务场景的待训练的图模型中。
如前所述,教师网络图模型是适用于第二业务场景中且为训练完毕的图模型。在本说明书实施例中,教师网络图模型不限于任何特定的网络结构,可以是适用于任一业务场景中的图聚合网络。当然,为了进一步提高知识蒸馏的效果,教师网络图模型所适用的第二业务场景可以与待训练的图模型适用的第一业务场景相近,即业务领域相近。比如风险控制业务已经是比较成熟的业务,在风险控制业务场景中,已经存在训练好的、性能优异的图模型,而对于比如网商业务场景或者其他新兴的业务场景,其训练样本较少,因此,当需要训练适用于网商业务场景或者其他新兴的业务场景的图模型时,可以使用风险控制业务场景中的图模型作为教师网络图模型。
接下来对于步骤203:从第一业务场景中得到训练样本。
第一业务场景是待训练的图模型所适用的业务场景,比如上述的网商业务场景或者其他新兴的业务场景。训练样本可以是根据第一业务场景的历史业务数据所得到的。可以理解,训练样本是带有标签的图数据,从图数据中能够获取到节点、边、特征、通过边所形成的节点之间的连接关系。
接下来对于步骤205:利用教师网络图模型与待训练的图模型分别学习训练样本。
本步骤205是让教师网络图模型与待训练的图模型分别学习训练样本,即学习同一个目标信息,从而后续可以让待训练的图模型借鉴教师网络图模型的学习结果。
对于一个图网络,其核心的内容包括图结构,即各个节点之间的连接关系,比如,节点1与哪一个节点直接相连,节点1是否能通过某一个节点与节点5相连。因此,在学习训练样本时,需要根据第一业务场景中得到的训练样本学习能够体现第一业务场景特点的图结构,即网络拓扑。
同时,对于一个图网络,其核心的内容还包括图特征,即图网络中每一个对象的特征。比如,在节点1与节点2之间连接有边,即节点1与节点2为一阶邻居时,节点1会具有何种特征、节点2会具有何种特征。因此,在学习训练样本时,需要根据第一业务场景中得到的训练样本学习能够体现第一业务场景特点的图特征。
可见,在本步骤205中需要从训练样本中同时提取出图结构的表征以及图特征的表征。相应地,在本说明书一个实施例中,参见图3、图6,步骤205的具体实现过程包括步骤2051至步骤2057:
步骤2051:利用教师网络图模型从训练样本中提取图结构,以得到第一图结构表征。
步骤2053:利用待训练的图模型从训练样本中提取图结构,以得到第二图结构表征。
步骤2055:利用教师网络图模型从训练样本中提取图特征,以得到第一图特征表征。
步骤2057:利用待训练的图模型从训练样本中提取图特征,以得到第二图特征表征。
在本说明书实施例中,可以利用邻居关系矩阵A以及特征矩阵B来实现上述步骤2051至步骤2057的过程。下面说明如何利用该两个矩阵来实现。
如前所述,需要根据第一业务场景中得到的训练样本学习能够体现第一业务场景特点的图结构,即网络拓扑。在本说明书一个实施例中,可以通过邻居关系矩阵记为A,来体现训练样本中各个节点之间的邻居关系,从而体现图结构即网络拓扑。需要说明的是,在本说明书实施例中,邻居关系指的是两个节点之间通过边直连,形成一阶的邻居关系。根据第一业务场景中得到的训练样本,在邻居关系矩阵A中,第一行的各个矩阵元素表示训练样本中的节点1与训练样本中的各个节点之间是否具有邻居关系;第二行的各个矩阵元素表示训练样本中的节点2与训练样本中的各个节点之间是否具有邻居关系,以此类推,第N行的各个矩阵元素表示训练样本中的节点N与训练样本中的各个节点之间是否具有邻居关系。
因此,在本说明书一个实施例中,该方法进一步包括:根据训练样本得到邻居关系矩阵;其中,邻居关系矩阵为N*N的矩阵,N为训练样本中包括的节点的数量;对于任意两个节点,如果该两个节点直连,形成一阶邻居关系,则在所述邻居关系矩阵中对应于该两个节点的矩阵元素的值为1,否则为0。比如,参见图4,A11表示的是节点1与节点1之间的邻居关系,比如在交易业务场景中,节点1为一个用户账户,该用户账户给自己发过红包或者转过帐,则节点1与节点1之间具有邻居关系,A11的值为1;A12表示节点1与节点2之间的邻居关系,比如节点1与节点2之间不具有邻居关系,则A12的值为0;A13表示节点1与节点3之间的邻居关系,比如节点1与节点3之间不具有邻居关系,则A12的值为0,以此类推。
如前所述,需要根据第一业务场景中得到的训练样本学习能够体现第一业务场景特点的图特征。在本说明书一个实施例中,可以通过特征矩阵记为B,来体现训练样本中每一个节点的各个特征。因此,在本说明书一个实施例中,该方法进一步包括:根据训练样本得到特征矩阵;其中,特征矩阵为N*M的矩阵,M为训练样本中每一个节点包括的特征的数量;特征矩阵中的每一行对应一个节点,该行中不同的矩阵元素表示该节点的不同的特征。在具体实现时,会让所有节点的特征的数量/维度相同,如果一个节点在某一个维度上没有特征,则可以将该节点在该维度上的特征值置为空,从而保证每一个节点都有M个/维特征值。比如,参见图5,根据第一业务场景中得到的训练样本,在特征矩阵B中,B11对应节点1(比如节点1可以为一个账户)的第1个特征(第1个特征是第一个维度的特征比如金额),B12对应节点1的第2个特征(第2个特征是第二个维度的特征比如交易时间),以此类推,B1M对应节点1的第M个特征(第M个特征是第M个维度的特征比如用户名称);同理,B21对应节点2(比如节点2可以为另一个账户)的第1个特征(即第一个维度的特征比如金额),B22对应节点2的第2个特征(即第二个维度的特征比如交易时间),以此类推。
在步骤2051及步骤2053中,需要提取出图结构的表征。因为图结构只跟邻居关系矩阵A相关,跟特征矩阵B不相关,也就是说,训练样本中节点的特征值不能影响图结构的表征,因此,参见图6,将特征矩阵B中的所有矩阵元素的值均设置为0,此时得到的特征矩阵记为第一特征矩阵;
相应地,参见图6,步骤2051中,利用教师网络图模型从训练样本中提取图结构的过程包括:将形成的邻居关系矩阵、第一特征矩阵输入教师网络图模型,从而得到该教师网络图模型输出的第一图结构表征;
相应地,参见图6,步骤2053中,利用待训练的图模型从训练样本中提取图结构的过程包括:将形成的邻居关系矩阵、第一特征矩阵输入待训练的图模型,得到该待训练的图模型输出的第二图结构表征。
可以得到,第一图结构表征与第二图结构表征不依赖于训练样本的任何特征,即不依赖于节点的特征及边的特征,仅仅依赖作为训练样本的图数据的邻居信息。
在步骤2055及步骤2057中,需要提取出图特征的表征。因为图特征既跟邻居关系矩阵A相关,即需要知道节点之间的邻居关系,也跟特征矩阵B相关,即需要知道每一个节点的特征值,因此,参见图6,针对特征矩阵B中的每一个矩阵元素,将该矩阵元素的值设置为该矩阵元素所对应节点的对应特征的特征值,此时得到的特征矩阵记为第二特征矩阵。比如如前所述,矩阵元素B11的值为节点1的第1个特征(第一个维度的特征比如金额)的特征值,比如为归一化的特征值0.5;矩阵元素B12的值为节点1的第2个特征(第二个维度的特征比如交易时间)的特征值,比如为归一化的特征值0.7;矩阵元素BNM的值为节点N的第M个特征(第M个维度的特征比如用户名称)的特征值,比如为归一化的特征值0.3。
相应地,参见图6,步骤2055中,利用教师网络图模型从训练样本中提取图特征的过程可以包括:将形成的邻居关系矩阵、第二特征矩阵输入教师网络图模型,得到该教师网络图模型输出的第一图特征表征;
相应地,参见图6,步骤2057中,利用待训练的图模型从训练样本中提取图结构的过程包括:将形成的邻居关系矩阵、第二特征矩阵输入待训练的图模型,得到该待训练的图模型输出的第二图特征表征。
接下来对于步骤207:针对教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失。
本步骤207的目的是为了约束两个图模型即教师网络图模型与待训练的图模型所分别提取出的表征尽可能相似,从而实现将教师网络图模型学习到的知识蒸馏到待训练的图模型中。
对应于上述图3所示的步骤205的实现过程,参见图3、图6,本步骤207的具体实现过程包括步骤2071及步骤2073:
步骤2071:对第一图结构表征与第二图结构表征进行相似性约束,从而得到第一差异损失,记为loss1。
本步骤2071的一种实现过程包括:利用计算均方误差的方法对第一图结构表征与第二图结构表征求相似,以得到第一差异损失。
比如,在本步骤2071中,第一图结构表征通常是一个一维数组,第二图结构表征也是一个一维数组,比如均为1*128维的数组。用均方误差的方法把两个一维数组求相似,即,在每一个维度上,两个数组中的值先求一个差值再对该差值求平方,得到在该维度上的差值的平方,最后得到128个差值的平方;把这128个差值的平方相加后再除以128,最后得到的结果就是第一差异损失,记为loss1,第一差异损失是一个标量值。
步骤2073:对第一图特征表征与第二图特征表征进行相似性约束,得到第二差异损失,记为loss2。
本步骤2073的一种实现过程包括:利用计算均方误差的方法对第一图特征表征与第二图特征表征求相似,以得到第二差异损失。
本步骤2073的实现方法可以参考对步骤2071的描述。
接下来对于步骤209:利用待训练的图模型对训练样本的学习结果,得到业务损失。
本步骤209的一种实现过程包括:
步骤2091:使用多层感知机对第二图特征表征进行推理,从而将第二图特征表征转换为0至1之间的预测数值。
步骤2093:利用预测数值与训练样本包括的标签信息计算二分类损失,从而得到业务损失。
第二图特征表征与训练样本的标签是对应的,也就是说,第二图特征表征是待训练的图模型针对训练样本所预测出来的业务结果,而标签是根据历史数据得到的训练样本的正确业务结果。需要将第二图特征表征与标签进行对比,来调整待训练图模型的模型参数。但是,图特征表征是一个一维数组比如为1*128维度的数组,而训练样本的标签是0-1之间的概率值,因此,需要执行步骤2091的处理,通过多层感知机对第二图特征表征进行推理,从而将第二图特征表征转换为0至1之间的预测数值,该预测数值表示概率值,这样,才能让第二图特征表征与标签都是0-1之间的数值,从而才能执行步骤2093的处理,即,计算二分类损失,从而得到业务损失,业务损失记为loss0。
上述步骤205和步骤207是通过教师网络图模型与待训练的图模型(待训练的图模型也可以称为学生网络图模型)的学习结果的差异,来得到差异损失,从而利用该差异损失调整待训练的图模型的模型参数。而本步骤209的过程则是利用待训练的图模型自己的学习结果,比如学习结果是否正确,来得到业务损失,从而利用该业务损失调整待训练的图模型的模型参数。
也就是说,步骤205的目的是实现图结构表征的知识蒸馏。步骤207的目的是实现图特征表征的知识蒸馏。步骤209的目的是第一业务场景的任务约束。通过步骤205、步骤207及步骤209配合的实现过程,来联合训练适用于第一业务场景的图模型。
接下来对于步骤211:根据差异损失以及业务损失,调整待训练的图模型的模型参数。
如前所述,差异损失包括:损失函数loss1,loss1体现教师网络图模型对训练样本的图结构表征进行知识蒸馏;损失函数loss2,loss2体现教师网络图模型对训练样本的图特征表征进行知识蒸馏。业务损失为loss0。
因此,参见图6,本步骤211中,可以设置待训练的图模型的约束函数为:L=loss0+α·loss1+β·loss2;
其中,L表示约束函数;loss0表示业务损失;loss1表示第一差异损失;loss2表示第二差异损失;α、β均为预先设置的超参约束,α、β的值越大,则借鉴教师网络图模型的知识的权重比例越大。α、β的值可以根据实际的业务需要来调整。
在本步骤211中,设置待训练的图模型的约束函数为:L=loss0+α·loss1+β·loss2之后,则可以根据该约束函数调整待训练的图模型的模型参数,之后进行下一轮的训练,即之后循环执行步骤203至步骤211,直至待训练的图模型收敛。
在本说明书的一个实施例中,提供了一种图模型的训练装置,参见图7,该装置包括:
教师网络获取模块701,配置为得到教师网络图模型;其中,所述教师网络图模型适用于第二业务场景中且为训练完毕的图模型;
训练样本得到模块702,配置为从第一业务场景中得到训练样本;
学习模块703,配置为利用教师网络图模型与待训练的图模型分别学习训练样本;
差异损失得到模块704,配置为对根据教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失;
业务损失得到模块705,配置为利用待训练的图模型对训练样本的学习结果,得到业务损失;
模型参数调整模块706,配置为根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。
在本说明书装置的一个实施例中,学习模块703被配置为执行:
利用教师网络图模型从训练样本中提取图结构,以得到第一图结构表征;
利用待训练的图模型从训练样本中提取图结构,以得到第二图结构表征;
利用教师网络图模型从训练样本中提取图特征,以得到第一图特征表征;
利用待训练的图模型从训练样本中提取图特征,以得到第二图特征表征;
对应地,差异损失得到模块704被配置为执行:
对第一图结构表征与第二图结构表征进行相似性约束,以得到第一差异损失;
对第一图特征表征与第二图特征表征进行相似性约束,以得到第二差异损失。
在本说明书装置的一个实施例中,进一步包括:训练样本处理模块801;
训练样本处理模块801,配置为根据所述训练样本得到邻居关系矩阵;其中,所述邻居关系矩阵为N*N的矩阵,N为根据训练样本得到的节点的数量;对于任意两个节点,如果该两个节点直连,形成一阶邻居关系,则在所述邻居关系矩阵中对应于该两个节点的矩阵元素的值为1,否则为0;以及,根据所述训练样本得到特征矩阵;其中,所述特征矩阵为N*M的矩阵,M为根据训练样本得到的每一个节点包括的特征的数量;所述特征矩阵中的每一行对应一个节点,该行中不同的矩阵元素表示该节点的不同的特征。
在本说明书装置的一个实施例中,训练样本处理模块801被配置为执行:将特征矩阵中的所有矩阵元素的值均设置为0,以得到第一特征矩阵;
相应地,学习模块703被配置为执行:
将所述邻居关系矩阵、所述第一特征矩阵输入所述教师网络图模型,以得到该教师网络图模型输出的第一图结构表征;
将所述邻居关系矩阵、所述第一特征矩阵输入所述待训练的图模型,以得到该待训练的图模型输出的第二图结构表征。
在本说明书装置的一个实施例中,训练样本处理模块801被配置为执行:对所述特征矩阵中的每一个矩阵元素,将该矩阵元素的值设置为该矩阵元素所对应节点的对应特征的值,以得到第二特征矩阵;
相应地,学习模块703被配置为执行:
将所述邻居关系矩阵、所述第二特征矩阵输入所述教师网络图模型,以得到该教师网络图模型输出的第一图特征表征;
将所述邻居关系矩阵、所述第二特征矩阵输入所述待训练的图模型,以得到该待训练的图模型输出的第二图特征表征。
在本说明书装置的一个实施例中,差异损失得到模块704被配置为执行:
利用计算均方误差的方法对第一图结构表征与第二图结构表征求相似,以得到所述第一差异损失;
利用计算均方误差的方法对第一图特征表征与第二图特征表征求相似,以得到所述第二差异损失。
在本说明书装置的一个实施例中,业务损失得到模块705被配置为执行:
使用多层感知机对所述第二图特征表征进行推理,从而将所述第二图特征表征转换为0至1之间的预测数值;
利用所述预测数值与训练样本包括的标签信息计算二分类损失,从而得到所述业务损失。
在本说明书装置的一个实施例中,模型参数调整模块706被配置为执行:设置待训练的图模型的约束函数为:L=loss0+α·loss1+β·loss2;
其中,L表示约束函数;loss0表示所述业务损失;loss1表示所述第一差异损失;loss2表示所述第二差异损失;α、β均为预先设置的超参约束,α、β的值越大,则借鉴教师网络图模型的知识的权重比例越大。
需要说明的是,上述各装置通常实现于服务器端,可以分别设置于独立的服务器,也可以其中部分或全部装置的组合设置于同一服务器。该服务器可以是单个的服务器,也可以是由多个服务器组成的服务器集群,服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品。上述各装置还可以实现于具有较强计算能力的计算机终端。
本说明书一个实施例提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行说明书中任一个实施例中的方法。
本说明书一个实施例提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现执行说明书中任一个实施例中的方法。
可以理解的是,本说明书实施例示意的结构并不构成对本说明书实施例的装置的具体限定。在说明书的另一些实施例中,上述装置可以包括比图示更多或者更少的部件,或者组合某些部件,或者拆分某些部件,或者不同的部件布置。图示的部件可以以硬件、软件或者软件和硬件的组合来实现。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、挂件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。
Claims (10)
1.图模型训练方法,其中,该待训练的图模型适用于第一业务场景中;该方法包括:
得到教师网络图模型;其中,所述教师网络图模型适用于第二业务场景中且为训练完毕的图模型;
从第一业务场景中得到训练样本;
利用教师网络图模型与待训练的图模型分别学习训练样本;
针对教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失;
利用待训练的图模型对训练样本的学习结果,以得到业务损失;
根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。
2.根据权利要求1所述的方法,其中,所述利用教师网络图模型与待训练的图模型分别学习训练样本,包括:
利用教师网络图模型从训练样本中提取图结构,以得到第一图结构表征;
利用待训练的图模型从训练样本中提取图结构,以得到第二图结构表征;
利用教师网络图模型从训练样本中提取图特征,以得到第一图特征表征;
利用待训练的图模型从训练样本中提取图特征,以得到第二图特征表征;
对应地,所述针对根据教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束得到差异损失,包括:
对第一图结构表征与第二图结构表征进行相似性约束,以得到第一差异损失;
对第一图特征表征与第二图特征表征进行相似性约束,以得到第二差异损失。
3.根据权利要求2所述的方法,其中,该方法进一步包括:
根据所述训练样本得到邻居关系矩阵;其中,所述邻居关系矩阵为N*N的矩阵,N为根据训练样本得到的节点的数量;对于任意两个节点,如果该两个节点直连,形成一阶邻居关系,则在所述邻居关系矩阵中对应于该两个节点的矩阵元素的值为1,否则为0;
根据所述训练样本得到特征矩阵;其中,所述特征矩阵为N*M的矩阵,M为根据训练样本得到的每一个节点包括的特征的数量;所述特征矩阵中的每一行对应一个节点,该行中不同的矩阵元素表示该节点的不同的特征。
4.根据权利要求3所述的方法,其中,该方法进一步包括:
将所述特征矩阵中的所有矩阵元素的值均设置为0,以得到第一特征矩阵;
所述利用教师网络图模型从训练样本中提取图结构,包括:
将所述邻居关系矩阵、所述第一特征矩阵输入所述教师网络图模型,以得到该教师网络图模型输出的第一图结构表征;
所述利用待训练的图模型从训练样本中提取图结构,包括:
将所述邻居关系矩阵、所述第一特征矩阵输入所述待训练的图模型,以得到该待训练的图模型输出的第二图结构表征。
5.根据权利要求3所述的方法,其中,该方法进一步包括:
针对所述特征矩阵中的每一个矩阵元素,将该矩阵元素的值设置为该矩阵元素所对应节点的对应特征的值,以得到第二特征矩阵;
所述利用教师网络图模型从训练样本中提取图特征,包括:
将所述邻居关系矩阵、所述第二特征矩阵输入所述教师网络图模型,以得到该教师网络图模型输出的第一图特征表征;
所述利用待训练的图模型从训练样本中提取图特征,包括:
将所述邻居关系矩阵、所述第二特征矩阵输入所述待训练的图模型,以得到该待训练的图模型输出的第二图特征表征。
6.根据权利要求2所述的方法,其中,所述对第一图结构表征与第二图结构表征进行相似性约束得到第一差异损失,包括:
利用计算均方误差的方法对第一图结构表征与第二图结构表征求相似,以得到所述第一差异损失;
所述对第一图特征表征与第二图特征表征进行相似性约束得到第二差异损失,包括:
利用计算均方误差的方法对第一图特征表征与第二图特征表征求相似,以得到所述第二差异损失。
7.根据权利要求2所述的方法,其中,所述利用待训练的图模型对训练样本的学习结果得到业务损失,包括:
使用多层感知机对所述第二图特征表征进行推理,从而将所述第二图特征表征转换为0至1之间的预测数值;
利用所述预测数值与训练样本包括的标签信息计算二分类损失,从而得到所述业务损失。
8.根据权利要求2所述的方法,其中,所述根据所述差异损失以及所述业务损失调整所述待训练的图模型的参数,包括:
设置所述待训练的图模型的约束函数为:L=loss0+α·loss1+β·loss2;
其中,L表示约束函数;loss0表示所述业务损失;loss1表示所述第一差异损失;loss2表示所述第二差异损失;α、β均为预先设置的超参约束,α、β的值越大,则借鉴教师网络图模型的知识的权重比例越大。
9.图模型训练装置,其中,该待训练的图模型适用于第一业务场景中;该装置包括:
教师网络获取模块,配置为得到教师网络图模型;其中,所述教师网络图模型适用于第二业务场景中且为训练完毕的图模型;
训练样本得到模块,配置为从第一业务场景中得到训练样本;
学习模块,配置为利用教师网络图模型与待训练的图模型分别学习训练样本;
差异损失得到模块,配置为对根据教师网络图模型对训练样本的学习结果与待训练的图模型对训练样本的学习结果进行相似性约束,以得到差异损失;
业务损失得到模块,配置为利用待训练的图模型对训练样本的学习结果,得到业务损失;
模型参数调整模块,配置为根据所述差异损失以及所述业务损失,调整所述待训练的图模型的模型参数。
10.一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-8中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211680435.6A CN115965079A (zh) | 2022-12-27 | 2022-12-27 | 图模型训练方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211680435.6A CN115965079A (zh) | 2022-12-27 | 2022-12-27 | 图模型训练方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115965079A true CN115965079A (zh) | 2023-04-14 |
Family
ID=87359564
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211680435.6A Pending CN115965079A (zh) | 2022-12-27 | 2022-12-27 | 图模型训练方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115965079A (zh) |
-
2022
- 2022-12-27 CN CN202211680435.6A patent/CN115965079A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Jaafra et al. | Reinforcement learning for neural architecture search: A review | |
WO2023065545A1 (zh) | 风险预测方法、装置、设备及存储介质 | |
Mousavi et al. | Traffic light control using deep policy‐gradient and value‐function‐based reinforcement learning | |
WO2017206936A1 (zh) | 基于机器学习的网络模型构造方法及装置 | |
CN107330731B (zh) | 一种识别广告位点击异常的方法和装置 | |
CN111695415A (zh) | 图像识别模型的构建方法、识别方法及相关设备 | |
CN111079532A (zh) | 一种基于文本自编码器的视频内容描述方法 | |
CN112069903B (zh) | 基于深度强化学习实现人脸识别端边卸载计算方法及装置 | |
CN112862092B (zh) | 一种异构图卷积网络的训练方法、装置、设备和介质 | |
CN111898703B (zh) | 多标签视频分类方法、模型训练方法、装置及介质 | |
US20220237917A1 (en) | Video comparison method and apparatus, computer device, and storage medium | |
CN112053327B (zh) | 视频目标物检测方法、系统及存储介质和服务器 | |
CN113706151A (zh) | 一种数据处理方法、装置、计算机设备及存储介质 | |
CN113822315A (zh) | 属性图的处理方法、装置、电子设备及可读存储介质 | |
CN114091667A (zh) | 一种面向非独立同分布数据的联邦互学习模型训练方法 | |
Milutinovic et al. | End-to-end training of differentiable pipelines across machine learning frameworks | |
CN111402156B (zh) | 一种涂抹图像的复原方法、装置及存储介质和终端设备 | |
Pedronette et al. | Rank-based self-training for graph convolutional networks | |
CN116402352A (zh) | 一种企业风险预测方法、装置、电子设备及介质 | |
CN114863092A (zh) | 一种基于知识蒸馏的联邦目标检测方法及系统 | |
CN113987236B (zh) | 基于图卷积网络的视觉检索模型的无监督训练方法和装置 | |
CN111079930A (zh) | 数据集质量参数的确定方法、装置及电子设备 | |
CN113326884A (zh) | 大规模异构图节点表示的高效学习方法及装置 | |
CN113705402A (zh) | 视频行为预测方法、系统、电子设备及存储介质 | |
CN116541779A (zh) | 个性化公共安全突发事件检测模型训练方法、检测方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |