CN117033992A - 一种分类模型的训练方法及装置 - Google Patents
一种分类模型的训练方法及装置 Download PDFInfo
- Publication number
- CN117033992A CN117033992A CN202210462062.9A CN202210462062A CN117033992A CN 117033992 A CN117033992 A CN 117033992A CN 202210462062 A CN202210462062 A CN 202210462062A CN 117033992 A CN117033992 A CN 117033992A
- Authority
- CN
- China
- Prior art keywords
- node
- isomorphic
- nodes
- graph
- same type
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 128
- 238000013145 classification model Methods 0.000 title claims abstract description 97
- 238000012549 training Methods 0.000 title claims abstract description 77
- 238000013528 artificial neural network Methods 0.000 claims abstract description 71
- 239000013598 vector Substances 0.000 claims abstract description 70
- 238000012512 characterization method Methods 0.000 claims abstract description 53
- 238000009826 distribution Methods 0.000 claims abstract description 25
- 230000006870 function Effects 0.000 claims description 30
- 238000010586 diagram Methods 0.000 claims description 14
- 238000013507 mapping Methods 0.000 claims description 14
- 230000007246 mechanism Effects 0.000 claims description 14
- 238000004590 computer program Methods 0.000 claims description 7
- 238000002372 labelling Methods 0.000 claims description 7
- 239000011159 matrix material Substances 0.000 claims description 7
- 230000002776 aggregation Effects 0.000 claims description 6
- 238000004220 aggregation Methods 0.000 claims description 6
- 238000003860 storage Methods 0.000 claims description 6
- 230000001360 synchronised effect Effects 0.000 claims description 6
- 238000006243 chemical reaction Methods 0.000 claims description 5
- 230000000875 corresponding effect Effects 0.000 description 31
- 239000000203 mixture Substances 0.000 description 15
- 238000005457 optimization Methods 0.000 description 10
- 238000012545 processing Methods 0.000 description 10
- 230000000694 effects Effects 0.000 description 9
- 201000001197 subcortical band heterotopia Diseases 0.000 description 8
- 238000004891 communication Methods 0.000 description 7
- 238000004364 calculation method Methods 0.000 description 6
- 230000002596 correlated effect Effects 0.000 description 6
- 238000002474 experimental method Methods 0.000 description 6
- 230000004927 fusion Effects 0.000 description 5
- 230000009286 beneficial effect Effects 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 230000009466 transformation Effects 0.000 description 3
- 230000001133 acceleration Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000001514 detection method Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 230000001502 supplementing effect Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 150000001875 compounds Chemical class 0.000 description 1
- 230000007812 deficiency Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000007499 fusion processing Methods 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 239000012633 leachable Substances 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000005065 mining Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000007500 overflow downdraw method Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 230000002195 synergetic effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本申请提供一种分类模型的训练方法,包括构建多个包括异构视图和同构视图的对比视图,将异构视图输入异构图神经网络,得到目标类型的节点的第一表征向量;将目标类型的节点的第一表征向量输入分类器,得到目标类型的节点的预测类别的概率分布;基于目标类型的节点的预测类别的概率分布和目标类型的节点的标签,得到监督损失;将同构视图输入同构图神经网络,得到目标类型的节点的第二表征向量;基于正样本对的相似度和负样本对的相似度,确定对比损失;以最小化监督损失和对比损失为目标,调整异构图神经网络的参数和分类器的参数,以得到训练后的分类模型。
Description
技术领域
本申请涉及人工智能技术领域,尤其涉及一种分类模型的训练方法及装置。
背景技术
现实世界数据集中的实体分类问题往往涉及多种类型的实体及其关系。例如,电子商务中的交易欺诈检测,就是根据交易、买家、邮寄地址、付款方式等信息,对一个交易进行欺诈\安全的二分类。类似的异质数据还广泛存在于学术数据库、影视资料库、医疗数据库等等场景下。目前,业界一种常用的思路是将异质数据抽象为异构图,从而将实体分类问题转化为图上的节点分类问题。
一般来说,各种异构图神经网络模型是以端到端的方式进行训练的,需要大量的、多种多样的标记数据用于不同的下游任务。然而,在大多数现实世界的场景中,丰富的标记数据通常是昂贵的,甚至是不可行的。例如,医疗数据、生物化学实验数据往往样本量有限、采集过程漫长;交易欺诈检测应用中的欺诈交易在所有交易中只占很小比例;对论文、学者的研究领域的判断需要专业知识,成本较高。
而在标记数据数量较少的情况下,利用一般的监督学习的训练方法训练出的模型对节点的预测分类的准确性较低,预测性能较差,很难满足使用要求。
发明内容
本申请的实施例提供一种分类模型的训练方法,利用对比学习的方法在少标签的条件下,较好的完成了分类模型的训练,实现训练后的分类模型具有较好的预测精度,满足使用要求。
第一方面,本申请提供了一种分类模型的训练方法,分类模型包括异构图神经网络和分类器,分类模型用于预测异构图中同一类型节点的标签的概率分布,该方法包括:获取数据集,数据集包括异构图和异构图中部分节点的标签,部分节点的类型相同;将异构图输入分类模型,以最小化监督损失函数和对比损失函数为目标,调整异构图神经网络的参数和分类器的参数,以得到训练后的分类模型;其中,监督损失函数指示部分节点预测的标签的概率分布和部分节点的标签之间的关系;对比损失函数指示正样本对的相似度与负样本对的相似度之间的关系,正样本对包括第一节点和第二节点,负样本对包括第一节点和第三节点,第一节点是异构图中的节点,第二节点和第三节点是与异构图对应的同构图中的节点,同构图中的节点的类型与部分节点的类型相同,第一节点和第二节点具有相同的类型和相同的标签,第一节点和第三节点具有相同的类型和不同的标签。
也就是说,正样本对和负样本对是针对于异构图中的同一个节点,形成的两组数据,例如,对于节点i,与节点i标签相同的节点j,与节点i标签不同的节点z,节点i、节点j和节点z均为同一类型的节点,则正样本对为(异构图中的节点i,同构图中的节点j);负样本对为(异构图中的节点i,同构图中的节点z)。
可以理解的是,异构图神经网络是可以处理异构图的图神经网络,也就是说,异构图神经玩网络是针对异构图进行视图编码,得到异构图中同一类型节点中的各个节点的嵌入向量,各个节点的嵌入向量表征各个节点的特征。
为了方便表述,下文中也可将同一类型节点称之为目标类型的节点,目标类型的节点为异构图节点分类问题中待分类的节点类型的全部节点。
本申请提供了一种分类模型的训练方法,异构图表征目标类型的节点与其他类型节点间的相关关系,同构图表征目标类型的节点间的相关关系,构建同构图作为异构图的对比视图,有益于下游的分类任务(目标类型的节点间的关系对于下游分类任务很重要),在构建对比样本对(正样本对和负样本对)时,考虑标签信息,对比学习任务和监督学习任务进行联合训练,统一了对比学习和下游分类任务的优化目标,有效提升了节点分类的准确度。
在一个可能的实现中,还包括,将同构图作为同构图神经网络的输入,输出同构图的节点中各个节点的第二表征向量;正样本对的相似度基于第一节点的第一表征向量和第二节点的第二表征向量的相似度得到,负样本对的相似度基于第一节点的第一表征向量和第三节点的第二表征向量的相似度得到;其中,同一类型节点中各个节点的第一表征向量基于异构图作为异构图神经网络的输入,异构图神经网络输出得到。
在另一个可能的实现中,同构图包括M个同构子图,M为大于1的正整数;将同构图作为同构图神经网络的输入,输出同构图的节点中各个节点的第二表征向量,包括:将M个同构子图分别输入其各自对应的同构图神经网络,得到M个同构子图中各个同构子图的视图表征;将各个同构子图的视图表征进行融合,得到同构图的视图表征,同构图的视图表征包括同构图的节点中的各个节点的第二表征向量。
在该可能的实现中,构建多个同构图充分表征目标类型的节点的特征语义关系以及各种高阶关系,充分挖掘了异构图中目标类型的节点间不同的关系,将多个同构子图进行融合,充分发挥了对比学习的作用,从而有益于下游分类任务。
在另一个可能的实现中,将各个同构子图的视图表征进行融合,得到同构图的视图表征,包括:基于各个同构子图的权重,对各个同构子图的表征进行加权聚合,得到同构图的视图表征,各个同构子图的权重与各个同构子图的注意力分数有关。
在另一个可能的实现中,各个同构子图的注意力分数基于注意力机制对各个同构子图的转换视图表征进行注意力运算得到;各个同构子图的转换视图表征为将各个同构子图的视图表征通过映射矩阵,映射至公共特征空间得到。
在该可能的实现中,通过设置注意力机制,计算各个同构子图的注意力分数,以各个同构子图的注意力分数为权重将各个同构子图的视图表征进行加权融合得到融合后的同构图的视图表征,以表征各个目标类型的节点的各种语义关系及各种高阶关系。
在该可能的实现中,M个同构子图包括第一同构子图和第二同构子图;第一同构子图基于同一类型节点的节点间的相似度构建;第二同构子图基于同一类型节点对应的元路径构建。
在该可能的实现中,公开了两种同构子图的构建方法,一种为,特征相似度构图法构建第一同构子图,表征节点特征语义关系的同构图;另一种为,元路径关系构图法构建第二同构子图,构建表征目标类型的节点间高阶关系的同构图。
在另一个可能的实现中,M个同构子图基于同一类型节点对应的多条不同的元路径构建。
在另一个可能的实现中,一种分类模型的训练方法还包括:基于同一类型节点中各个节点的预测标签的概率分布,得到同一类型节点中各个节点的预测置信度;将预测置信度达标的同一类型节点以预测标签标注标签。
可以将标注有标签的节点集合称为标签样本集,在该可能的实现中,基于目标类型的节点中各个节点的预测置信度,扩充标签样本集,以弥补标注样本不足的问题。
在一个示例中,将预测置信度大于或等于阈值的同一类型节点确定为预测置信度达标的所述同一类型节点。
在另一个可能的实现中,阈值基于部分节点对应的预测置信度均值得到;预测置信度均值表征部分节点中的各个节点对应的预测置信度的平均值。
在该可能的实现中,利用已标注的标签样本集,从无标签样本中挑选出同类样本,并补充进标签样本集,自适应的扩充标签样本集,弥补了标注数据不足的问题。
在另一个可能的实现中,对比损失与正样本对的相似度正相关,与负样本对的相似度负相关,以使正样本对之间的特征更加相似,使负样本对之间的特征更加的不相似。
第二方面,本申请提供了一种分类方法,包括:获取异构图,异构图包括多种类型的节点;将异构图输入分类模型,得到同一类型节点中各个节点的预测标签;其中,分类模型基于第一方面所述的方法训练得到。
采用本申请第一方面提供的分类模型的训练方法训练得到的分类模型,有效提升了异构图中各个目标类型的节点的标签预测准确度。
第三方面,本申请提供了一种分类模型的练装置,所述分类模型包括异构图神经网络和分类器,所述分类模型用于预测异构图中同一类型节点的标签的概率分布,所述训练装置包括:
获取模块,用于获取数据集,所述数据集包括异构图和所述异构图中部分节点的标签,所述部分节点的类型相同;
调参模块,用于将所述异构图输入所述分类模型,以最小化监督损失函数和对比损失函数为目标,调整所述异构图神经网络的参数和所述分类器的参数,以得到训练后的所述分类模型;
其中,所述监督损失函数指示所述部分节点预测的标签的概率分布和所述部分节点的标签之间的关系;所述对比损失函数指示正样本对的相似度与负样本对的相似度之间的关系,所述正样本对包括第一节点和第二节点,所述负样本对包括所述第一节点和第三节点,所述第一节点是异构图中的节点,所述第二节点和所述第三节点是与所述异构图对应的所述同构图中的节点,所述同构图中的节点的类型与所述部分节点的类型相同,所述第一节点和所述第二节点具有相同的类型和相同的标签,所述第一节点和所述第三节点具有相同的类型和不同的标签。
在一个可能的实现中,还包括第二表征向量确定模块,
用于将所述同构图作为同构图神经网络的输入,输出所述同构图的节点中各个节点的第二表征向量;
所述正样本对的相似度基于所述第一节点的第一表征向量和第二节点的第二表征向量的相似度得到,所述负样本对的相似度基于所述第一节点的第一表征向量和第三节点的第二表征向量的相似度得到;其中,所述同一类型节点中各个节点的第一表征向量基于所述异构图作为所述异构图神经网络的输入,所述异构图神经网络输出得到。
在另一个可能的实现中,所述同构图包括M个同构子图,所述M为大于1的正整数;
所述第二表征向量确定模块还用于:
将所述M个同构子图分别输入其各自对应的同构图神经网络,得到所述M个同构子图中各个同构子图的视图表征;
将所述各个同构子图的视图表征进行融合,得到所述同构图的视图表征,所述同构图的视图表征包括所述同构图的节点中的各个节点的第二表征向量。
在另一个可能的实现中,所述将各个同构子图的视图表征进行融合,得到所述同构图的视图表征,包括:
基于所述各个同构子图的权重,对所述各个同构子图的表征进行加权聚合,得到所述同构图的视图表征,所述各个同构子图的权重与各个同构子图的注意力分数有关。
在另一个可能的实现中,所述各个同构子图的注意力分数基于注意力机制对所述各个同构子图的转换视图表征进行注意力运算得到;
所述各个同构子图的转换视图表征为将所述各个同构子图的视图表征通过映射矩阵,映射至公共特征空间得到。
在另一个可能的实现中,所述M个同构子图包括第一同构子图和第二同构子图;
所述第一同构子图基于所述同一类型节点的节点间的相似度构建;
所述第二同构子图基于所述同一类型节点对应的元路径构建。
在另一个可能的实现中,所述M个同构子图基于所述同一类型节点对应的多条不同的元路径构建。
在另一个可能的实现中,还包括标签样本集扩充模块,还包括:基于所述同一类型节点中各个节点的预测标签的概率分布,得到所述同一类型节点中各个节点的预测置信度;
将所述预测置信度达标的所述同一类型节点以所述预测标签标注标签。
在另一个可能的实现中,将预测置信度大于或等于阈值的所述同一类型节点确定为所述预测置信度达标的所述同一类型节点。
在另一个可能的实现中,所述阈值基于所述部分节点对应的预测置信度均值得到;
所述预测置信度均值表征所述部分节点中的各个节点对应的预测置信度的平均值
第四方面,本申请提供了一种分类装置,包括:
获取模块,用于获取异构图信息,所述异构图信息包括多种类型的节点和表征节点间的相关关系的连接边,所述多种类型的节点包括目标类型的节点;
推理模块,用于将所述异构图信息输入分类模型,得到所述目标类型的节点中各个节点的预测类别;
其中,所述分类模型基于第一方面所述的方法训练得到。
第五方面,本申请实施例提供一种计算设备,包括存储器和处理器,所述存储器中存储有指令,当所述指令被处理器执行时,使得第一方面和/或第二方面所述的方法被实现。
第六方面,本申请实施例提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在被处理器执行时,使得第一方面和/或第二方面所述的方法被实现。
第七方面,本申请实施例还提供一种计算机程序或计算机程序产品,该计算机程序或计算机程序产品包括指令,当所述指令执行时,令计算机执行第一方面和/或第二方面所述的方法。
第八方面,本申请实施例还提供一种芯片,包括至少一个处理器和通信接口,所述处理器用于执行第一方面和/或第二方面所述的方法。
附图说明
图1为异构图对比学习的通用框架示意图;
图2a为本申请实施例提供的分类模型的训练方法的框架示意图;
图2b为本申请实施例提供的一种分类模型的训练方法的流程图;
图3为本申请实施例提供的一种分类模型的训练方法的流程图;
图4为以元路径构建同构子视图的示意图;
图5为本申请实施例提供的分类方法的流程图;
图6为一种关于电影分类任务的异构图;
图7为本申请实施例提供的一种分类模型的训练装置的结构示意图;
图8为本申请实施例提供的一种分类装置的结构示意图;
图9为本申请实施例提供的计算设备的结构示意图。
具体实施方式
下面通过附图和实施例,对本申请的技术方案做进一步的详细描述。
针对数据的标签不足问题,可采用半监督学习方法来克服,半监督学习方法可大致分为两类:生成式学习和对比学习。其中,对比学习适用于分类任务。
对比学习挖掘数据内部结构,在不同的视图下最大化特征一致性来学习嵌入,异构图对比学习的通用框架可参见图1。
第一种异构图的对比学习方案为:对比GNN在异构图上的预训练策略(contrastive pre-training of GNNs on Heterogeneous graphs,CPT-HG),原始异构图输入异构图神经网络(heterogeneous graph neural network,HGN)得到每个节点的表征,在局部采样正负样本对计算对比损失。
关系预训练任务分为两部分。首先,CPT-HG希望保留节点之间的关系语义,为此设计了一个考虑边类型R的对比损失函数,如果原图中存在边(u,R,v),则u,v的表征作为关系R下的正样本对,而与u通过其他关系R′连接的节点v′的表征作为负样本对。其次,CPT-HG希望区分图的连接结构,如果图中u,v两节点通过任意边连接,则它们的表征为正样本对,而不以任何类型的边连接的节点u,v′的表征为负样本对。
子图预训练任务的目标是保留异构图中的高阶语义语境。它对于任意节点u抽取其周围的元图(metagraph)结构p,聚合p中的节点表征,得到节点u的高阶语义语境表征f(u,p)。f(u,p)与节点u的表征作为正样本对,其他节点的f(u’,p)与节点u的表征作为负样本对,进行对比。
但是CPT-HG仍存在着一些问题,例如,CPT-HG直接使用原始异构图作为单一视图,没有充分发挥对比学习寻求不同视图下特征一致性的效果,挖掘异构图信息的能力有限;CPT-HG采样局部连接的节点对、节点与局部元图语境作为对比样本,强调局部结构中的特征一致性,最高阶的关系抽取仍局限在一个元图中,无法建模在图中跨度更广的节点关系;采取预训练+微调的模式利用标签信息,使得对比损失与下游任务优化目标不一致等问题。
第二种方案为:具有协同对比学习的自监督异构图神经网络模型(self-supervised heterogeneous graph neural network with co-contrastive learning,HeCo),构建了网络模式视图和元路径视图,在网络模式视图下,HeCo使用注意力机制聚合异构图上节点u的一阶邻居(由边直接相连的)生成u的表征。在元路径视图下,HeCo使用注意力机制聚合异构图上与节点u的一阶元路径邻居(由元路径相连的)生成u的表征。
如果两个节点由足够多的元路径相连,它们是一对正样本,否则为一对负样本,使用一个样本的网络模式表征与另一个样本的元路径表征计算对比损失。
但是HeCo仍然存在着一些问题,例如,HeCo分别通过注意力机制聚合一阶邻居信息、一阶元路径邻居信息作为对比视图,不仅忽略了节点特征相似度蕴含的语义关系,也无法直接建模目标节点间的更高阶语义关系。其次,HeCo选取有多条元路径连接的节点为正样本对,与CPT-HG类似,它采取预训练+微调的模式利用标签信息,没有拉近同类别样本的距离,使得对比损失与下游任务优化目标不一致等问题。
本申请实施例针对上述问题,提供一种分类模型的训练方法,构建多个对比视图充分表征各种关系,以充分挖掘异构图中目标类型的节点的不同关系,在构建对比样本对时考虑标签信息,以统一了对比学习与下游任务的优化目标,在节点分类准确度上达到了明显优于业界一流方案的效果。
图2a为本申请实施例提供的分类模型的训练方法的框架图。由异构图Ghe和一个有标签样本集合L作为输入,输出训练好的分类模型的模型参数。如图2a所示,在模块1中构建多个对比视图,原始异构图刻画了节点局部连接情况,往往包含了目标类型的节点与其他类型节点的各种交互。而目标类型的节点Vtar之间还存在着特征语义关系、由元路径定义的高阶关系等。本方法构建多个同构视图G1,G2…Gm,以分别表达目标节点间的多种关系。
在模块2中对多个对比视图进行视图编码,将原始异构图输入HGN模型(包括但不限于SimpleHGN等)得到目标类型的节点的异构视图表征Hhe。同构视图分别输入同构图神经网络(graph neural network,GNN)模型(包括但不限于同构图卷积神经网络(graphconvolutional neural Network,GCN)模型等)得到目标类型的节点的同构表征H1,H2…Hm,并进一步进行视图表征融合,得到目标类型的节点的同构视图表征Hho。
在模块3中,计算监督损失(即图2a中的Supervised loss),异构视图表征输入分类器(例如,线性层),得到目标类型的节点的归一化分类预测分数同时,对于i∈L用Pi与真实标签计算有监督损失。
在模块4中,计算对比损失(即图2a中的Contrastive loss),对于所有目标节点i∈Vtar,利用Pi计算节i的标签预测置信度zi来表示对预测标签的确信程度。不断将标签预测置信度高的节点加入可信已标注集合根据/>为节点i构建正样本候选集Ci计算对比损失。
以最小化监督损失和对比损失为目标,调整异构图神经网络和分类器的参数,得到训练完成的分类模型。
下面通过流程图,详细介绍本申请实施例提供的分类模型的训练方法的详细实现方案。
图2b为本申请实施例提供的一种分类模型的训练方法的流程图。该方法可以通过任何具有计算能力的装置、设备、平台或设备集群来执行,例如服务器,该服务器可以具有AI芯片,例如包括图形处理器(graphics processing unit,GPU),加速处理器(Accelerated Processing Units,APU),嵌入式神经网络处理器(Neural-NetworkProcessing Unit,NPU)等。本申请对执行该方法的具体计算设备不做具体限定,可根据需要选择合适的计算设备执行。如图2b所示,该分类模型训练方法,至少包括步骤S201-S202。
在步骤S201中,获取数据集。
本申请实施例提供的分类模型的训练方法,针对训练的分类模型包括异构图神经网络和分类器,分类模型用于预测异构图中同一类型节点的标签的概率分布。
可以理解的是,异构图神经网络是可以处理异构图的图神经网络,也就是说,异构图神经玩网络是针对异构图进行视图编码,得到异构图中同一类型节点中的各个节点的嵌入向量,各个节点的嵌入向量表征各个节点的特征。
为了方便表述,下文中也可将同一类型节点称之为目标类型的节点,目标类型的节点为异构图节点分类问题中待分类的节点类型的全部节点。
数据集包括异构图和异构图中部分节点的标签,部分节点的类型相同。也就是说,数据集包括用于训练的异构图,以及具有预先标注的标签的部分节点集合,该节点集合可以称之为标签样本集L。
在步骤S202中,将异构图输入分类模型,以最小化监督损失函数和对比损失函数为目标,调整异构图神经网络的参数和分类器的参数,以得到训练后的分类模型。
其中,对比损失函数指示正样本对的相似度与负样本对的相似度之间的关系,正样本对包括第一节点和第二节点,负样本对包括第一节点和第三节点,第一节点是异构图中的节点,第二节点和第三节点是与异构图对应的同构图中的节点,同构图中的节点的类型与部分节点的类型相同,第一节点和第二节点具有相同的类型和相同的标签,第一节点和第三节点具有相同的类型和不同的标签。
本申请实施例提供的分类模型的训练方法,异构图表征目标类型的节点与其他类型节点间的相关关系,同构图表征目标类型的节点间的相关关系,构建同构图作为异构图的对比视图,有益于下游的分类任务(目标类型的节点间的关系对于下游分类任务很重要),在构建对比样本对(正样本对和负样本对)时,考虑标签信息,对比学习任务和监督学习任务进行联合训练,统一了对比学习和下游分类任务的优化目标,有效提升了节点分类的准确度。
下面详细介绍本申请实施例提供的一种分类模型的训练方法的详细实现。
图3为本申请实施例提供的一种分类模型的训练方法的流程图。该方法可以通过任何具有计算能力的装置、设备、平台或设备集群来执行,例如服务器,该服务器可以具有AI芯片,例如包括图形处理器(graphics processing unit,GPU),加速处理器(Accelerated Processing Units,APU),嵌入式神经网络处理器(Neural-NetworkProcessing Unit,NPU)等。本申请对执行该方法的具体计算设备不做具体限定,可根据需要选择合适的计算设备执行。如图3所示,该分类模型训练方法,至少包括步骤S301-S308。
在步骤S301中,获取训练数据集,训练数据集包括异构图和标签样本集。
异构图包括多种类型的节点,和表征节点间的相关关系的连接边,多种类型的节点包括目标类型的节点,目标类型的节点Vtar为异构图节点分类问题中待分类的节点类型的全部节点。
容易理解的是,异构图属于图结构的一种,由于其包含了不同类型的节点和连接边,节点可以表示真实世界中的各种实体,连接边表示各个节点之间的关系,它可以自然地表示许多真实世界的数据集。以图6中的影视数据集的异构图为例,该异构图包括多个不同类型的节点,例如电影(Movie)主题节点m1、m2、m3,演员(Actor)节点a1、a2、a3,导演(Director)节点d1、d2;各个节点之间具有连接边,表示两个节点之间具有相关关系,例如节点a1和节点m1之间具有连接边,表示演员a1参演了电影m1。
标签样本集L包括目标类型的节点中的部分节点,部分节点具有预先标注的标签,也就是说,标签样本集为已标注的目标类型的节点集合。
获取训练数据集的方式可以为用户输入,或者为从训练数据库中调取等方式。
在步骤S302中,基于异构图信息构建多个对比视图,对比视图包括异构图视图和同构视图。
异构视图包括多种类型的节点,以及表征节点间相关关系的连接边,多种类型中包括目标类型,可以将原始异构图直接作为异构视图,异构视图刻画了节点局部连接情况,包含了目标类型的节点与其他类型节点的各种交互关系信息。
同构视图包括目标类型的节点,以及表征目标类型的节点间相关关系的连接边。同构视图刻画了目标类型的节点间的各种相关关系。
可以通过多种方法基于异构图信息构建同构视图,例如,包括特征相似度构图方法,计算异构图信息中的节点特征的相似度,在相似度超过一定阈值的节点之间连边,构建表征节点特征语义关系的同构视图。元路径关系构图方法,预定义以目标类型的节点为头尾的元路径类型,如果两节点在异构图上以某条元路径相连,则在它们之间连边,构建表示目标类型的节点间高阶关系的同构视图(如图4所示,以Movie-Director-Movie元路径构建G1)。
可以理解的是,元路径被定义为一个路径Ф=A1→A2→…→Al+1,描述了节点类型A1和Al+1之间的复合关系R=R1°R2°…°RL。举例说明,如图4所示,电影类型节点M2、M3均与导演类型节点D1具有连接边,则具有元路径M2-D1-M3,表示同一个导演导的电影。
根据实际应用中选用的不同特征相似度计算方法、不同元路径类型、其他构图方法,可生成多个同构视图。这些同构视图表征包含丰富的信息,可以通过对比帮助HGN模型训练。
相似度计算方法包括但不限于余弦相似度、欧式距离、马氏距离、曼哈顿距离或其他用于计算相似度的函数等,此处不做穷举。
需要解释道是。若无特殊说明,本文中提及的异构图与异构视图含义相同,同构图与同构视图含义相同,同构子图与同构子视图含义相同。
在步骤S303中,将异构视图输入异构图神经网络,得到目标类型的节点中各个节点的第一表征向量。
将异构视图输入异构图神经网络,异构图神经网络对异构视图中的目标类型的节点进行计算得到异构视图中目标类型的节点中各个节点的第一表征向量,也就是说,提取得到异构视图中目标类型的节点中各个节点的嵌入表征向量(Embedding Vector)。
示例性的,将异构视图输入异构图神经网络,异构图神经网络对异构视图中的目标类型的节点进行计算,提取得到目标类型的节点与其他类型的节点的交互信息,作为表征目标类型的节点的第一表征向量。
在步骤S304中,将目标类型的节点中各个节点的第一表征向量输入分类器得到目标类型的节点中各个节点的预测类别的概率分布。
将异构图神经网络输出目标类型的节点中的各个节点的第一表征向量输入分类器(例如,线性层)中,分类器对目标类型的节点中的各个节点的第一表征向量进行分类计算得到目标类型的节点中各个节点的预测类别的概率分布,也可以称之为目标类型的节点的归一化预测分数
在步骤S305中,基于目标类型的节点中的部分节点的预测类别的概率分布和目标类型的节点中的部分节点的标签,得到监督损失。
例如,目标类型的节点的标签可以“1”或“0”表示节点为A类别或非A类别,针对某一目标类型的节点i,当分类器输出的节点i的预测类别为A的概率分布为0.8时,查询标签样本集L对应的节点i标注的标签为1,则节点i的监督损失则为1减去0.8等于0.2。
在步骤S306中,将同构视图输入同构图神经网络,得到目标类型的节点中各个节点的第二表征向量。
将同构视图输入同构图神经网络,同构图神经网络对同构视图中的目标类型的节点进行计算得到同构视图中各个节点的第二表征向量,也就是说,提取得到同构视图中各个节点的嵌入表征向量(Embedding Vector)。
由步骤S302中部分描述可知,同构视图包括目标类型的节点,以及表征目标类型的节点间相关关系的连接边。同构视图刻画了目标类型的节点间的各种相关关系。因此,同构图神经网络输出的为表征目标类型的节点间关系的嵌入表征向量。
示例性的,将同构视图输入同构图神经网络,同构图神经网络对同构视图中的各个节点进行计算,提取得到目标类型的节点间的各种关系信息(例如特征语义关系和目标节点间的高阶关系),作为表征目标类型的节点间的各种关系语义的第二表征向量。
在步骤S307中,基于正样本对的相似度和负样本对的相似度,确定对比损失。
其中,正样本对包括异构视图中的目标类型的第一节点和同构视图中的目标类型的第二节点,第一节点和第二节点具有相同的标签,负样本对包括目标类型的第一节点和同构视图中的目标类型的第三节点,第一节点和第三节点具有不同的标签。
也就是说,正样本对为标签一致的异构图视中目标类型的节点和同构视图中的目标类型的节点,负样本对为标签不一致的异构图视图中目标类型的节点和同构视图中的目标类型的节点。
例如,对于一个目标类型的节点i,如果节点i在标签样本集中,我们将所有在标签样本集L中且与它有相同标签的样本加入正样本候选集Ci;如果不在标签样本集/>中,我们无法判断哪些样本是与它同类的,其正样本候选集Ci中仅有它本身,正样本候选集Ci的确定方式可参见下述公式:
对于一个目标类型的节点i,与正样本候选集Ci中的任一节点组成正样本对,从目标类型的节点中非Ci样本集中的任一节点组成负样本对。
正样本对的相似度基于第一节点的第一表征向量和第二节点的第二表征向量的相似度得到,负样本对的相似度基于第一节点的第一表征向量和第三节点的第二表征向量的相似度得到。
在一个示例中,对于任一目标类型的节点i的对比损失,利用如下对比损失函数计算:
对比损失:
其中sim(·)是相似度函数(例如余弦相似度),t是温度系数。
节点i的对比损失与正样本对的相似度正相关,与负样本对的相似度负相关,以使正样本对之间的特征更加相似,使负样本对之间的特征更加的不相似,使得正样本对比负样本对在表征空间中更加接近。
在步骤S308中,以最小化监督损失和对比损失为目标,调整异构图神经网络的参数和分类器的参数,以得到训练后的分类模型。
联合优化监督损失和对比损失,反向传播更新分类模型的参数,也就是说,分类模型的损失函数为监督损失与对比损失的和,分类模型的损失函数如下:
其中,λ为可调超参数。
容易理解的是,上文提及的调整异构图神经网络的参数和分类器的参数,是指在训练过程中,调整异构图神经网络的权重参数和调整分类器的权重参数。
本申请实施例提供的一种分类模型的训练方法,首先采用对比学习的方法实现了在少标签的情况下,完成对模型的训练,其次,在构建对比视图中以原始异构图作为异构视图,表征目标类型的节点与其他类型节点间的相关关系,同时还额外构建了同构视图表征目标类型的节点间的相关关系,有益于下游的分类任务(目标类型的节点间的关系对于下游分类任务很重要),在构建对比样本对(正样本对和负样本对)时,考虑标签信息,统一了对比学习和下游分类任务的优化目标,有效提升了节点分类的准确度。
训练后的分类模型可以应用于各种具体的业务中以实现分类或识别的功能。将业务数据对应的异构图输入分类模型中,输出异构图中目标类型的节点中各个节点的预测标签,实现对目标类型的节点的识别或分类。
可以理解的是,异构图中的节点表征业务数据中的各种实体数据,例如,目标类型的节点表征业务数据中待分类或待识别的实体,例如,业务数据可以为电子商务相关的业务数据,该业务数据包括多种实体,例如交易、买家、电子邮箱、付款方式、邮寄地址等实体,电子商务相关的业务数据对应的异构图中的各个节点分别表征交易、买家、电子邮箱、付款方式、邮寄地址等实体,目标类型的节点表征待分类的业务实体,例如交易,则将电子商务相关的业务数据对应的异构图输入分类模型中,输出异构图中各个交易节点的预测标签,例如欺诈/安全,以识别该交易是否为安全交易。
再例如,业务数据可以为影视相关的业务数据,该业务数据包括多种实体,例如演员、电影和导演等实体,影视相关的业务数据对应的异构图中的各个节点分别表征演员、电影和导演等实体,目标类型的节点表征待分类的业务实体,例如演员,则将影视相关的业务数据对应的异构图输入分类模型中,输出异构图中各个演员节点的预测标签,例如参演了那部电影。应当理解,此处举例仅为方便对本申请实施例的应用场景进行理解,不对本申请实施例的应用场景进行穷举。
应用本申请实施例提供分类模型的训练方法,无需较多标注数据,可以较快的训练出符合要求的分类模型,节省了训练设备(例如训练服务器)的计算资源。
在一个示例中,步骤S302中构建的同构视图包括多个同构子视图,如图2a中的G1,G2…Gm,包括M个同构子视图,每个同构子视图对应一个同构图神经网络,例如,图卷积神经网络(graph convolutional neural network,GCN),也就是说,同构图神经网络也包括M个。
如图2a所示,将M个同构子视图输入M个同构图神经网络,得到M个同构子视图中各个同构子视图的视图表征H1,H2…Hm,将各个同构子视图的视图表征进行融合,得到融合后的同构视图的视图表征Hho。
可以通过多种方法基于异构图信息构建同构子视图,例如,包括特征相似度构图方法,计算异构图信息中的节点特征的相似度,在相似度超过一定阈值的节点之间连边,构建表征节点特征语义关系的同构子视图。元路径关系构图方法,预定义以目标类型的节点为头尾的元路径类型,如果两节点在异构图上以某条元路径相连,则在它们之间连边,构建表示目标类型的节点间高阶关系的同构子视图(如图4所示,以Movie-Director-Movie元路径构建G1)。
根据实际应用中选用的不同特征相似度计算方法、不同元路径类型、其他构图方法,可生成多个同构视图。这些同构视图表征包含丰富的信息,可以通过对比帮助HGN模型训练。
相似度计算方法包括但不限于余弦相似度、欧式距离、马氏距离、曼哈顿距离或其他用于计算相似度的函数等。
在一些其他的可能实现中,M个同构子视图还可以以一种方法构建,例如,M个同构子视图基于目标类型的节点对应的多条不同元路径构建,也就是说多条不同元路径构建多个同构子视图。
实现多个同构子视图的融合方法可以为:对各个同构子视图的视图表征进行注意力运算,得到各个同构子视图的注意力分数;基于各个同构子视图的注意力分数,得到各个同构子视图的权重;基于各个同构子视图的权重,对各个同构子视图的表征进行加权聚合,得到同构视图的视图表征。
各个同构子视图的注意力分数获取方法为:将各个同构子视图的视图表征通过映射矩阵,映射至公共特征空间,得到各个同构子视图的转换视图表征;基于注意力机制对各个同构子视图的转换视图表征进行注意力运算,得到各个同构子视图的注意力分数。
可以理解的是,注意力机制是用来学习和计算各个同构子视图的视图表征对融合后的同构视图的视图表征的贡献大小,也就是说,注意力分数越高的同构子视图,在融合过程中的权重越大,以突出注意力分数高的同构子视图的视图表征对融合后的同构视图的视图表征的影响。
示例性的,对于某一同构子视图Gj对应的所有节点表征Hj,其中的某一节点i经过线性映射后计算注意力分数这里Watt和/>是可学习的映射矩阵和注意力向量,同构子视图Gj上所有节点集合V的注意力分数均值即为该视图的注意力分数wj。注意力分数wj通过下述公式计算得到:
各视图的注意力分数经过softmax处理后作为权重,加权聚合节点表征,得到融合后的同构视图表征:
目标类型的节点之间的关系对于分类任务十分重要,但是现有方法选择的对比视图无法直接建模目标类型的节点之间的关系。本申请实施例提供的分类模型的训练方法,建立以目标类型的节点为节点集合的同构视图,直接建模目标类型的节点间的各种关系,从而有益于下游任务。其中,不仅设计了两种构图方式以充分挖掘目标类型的节点间在特征空间、结构空间中的语义关系,而且还提供了多个同构子视图表征融合机制,使得在使用中可以依据需要建立更多的同构子视图,得到融合的同构视图表征,充分发挥对比学习作用。
表1展示了本申请与包括CPT-HG和HeCo在内的现有技术的区别。
表一
在一个示例中,本申请实施例提供的一种分类模型的训练方法,还包括基于目标类型的节点中各个节点的预测类别的概率分布,得到目标类型的节点中各个节点的预测置信度;将预测置信度达标的目标类型的节点,以及预测置信度达标的目标类型的节点对应的预测标签,补充至标签样本集。也就是说,本申请实施例提供的分类模型的训练方法还包括,扩充标签样本集的步骤。
例如,将预测置信度大于或等于阈值的目标类型的节点确定为预测置信度达标的目标类型的节点。
可选的,该阈值可以为用户设置,例如阈值为0.9,即将预测置信度大于0.9的目标类型的节点补充至标签样本集,也可称之为可信已标注集合
可选的,阈值基于具有预先标注的标签的目标类型的节点对应的预测置信度均值得到;预测置信度均值表征具有预先标注的标签的目标类型的节点中的各个节点对应的预测置信度的平均值。
示例性的,首先为所有目标类型的节点计算标签预测置信度zi,然后根据标签预测置信度扩充可信已标注集合表示那些模型认为预测标签与真实标签相同的样本集合,帮助我们从无标签样本中挑选出同类样本。
对于目标类型的节点i,利用分类器输出的标签预测概率分布Pi,计算标签预测置信度zi,以衡量模型对节点分类结果的确信程度。若Pi在c个类别上概率分布具有较大峰度,则标签预测置信度较高。
初始化为L。之后,动态地将标签预测置信度超过阈值的节点加入集合。阈值由有标签样本当前的标签预测置信度的均值动态决定。
现有技术大多采用预训练+微调的模式,在选取对比样本时只考虑自监督信息,没有拉近同类别样本的距离,使得对比损失与下游任务优化目标不一致,影响了预测效果的进一步提升。本申请实施例提供的分类模型的训练方法以半监督范式,在构建对比样本对时考虑标签信息,将模型预测标签相同且模型认为预测结果可信的同类节点作为正样本对,对比损失与有监督损失协同优化,统一了对比学习与下游任务的优化目标。同时,自适应地扩充可信已标注集合,以弥补标注数据的不足。本申请实施例提供的分类模型的训练方法与包括CPT-HG和HeCo在内的现有技术的改进之处如表二所示。
表二
本申请实施例还提供了一种分类方法,具体实现如图5所示,包括步骤S501和步骤S502。
在步骤S501中,获取异构图信息。
异构图信息包括多种类型的节点,和表征节点间的相关关系的连接边,多种类型的节点包括目标类型的节点,目标类型的节点Vtar为异构图节点分类问题中待分类的节点类型的全部节点。
在步骤S502中,将异构图信息输入分类模型,得到目标类型的节点中各个节点的预测类别。
其中,分类模型基于本申请实施例提供的分类模型的训练方法训练得到。
分类模型的具体训练方法,可参见上文分类模型的训练方法部分描述,为了简洁,这里不再赘述。
为了验证本申请在异构图节点分类问题上的效果,我们在IMDB影视资料库中的电影分类任务上进行了实验。如图6所示,该数据集有电影(M)、演员(A)、导演(D)3种节点类型,目标类型节点为电影。实验时,首先预热50个周期不更新可信已标注样本集,然后在每个周期根据模型预测结果进行扩充。我们分别采样20\40\60个样本作为有标签样本集,1000个样本作为测试集,与前述的异构图表征学习方法HeCo进行比较。所有实验重复5次,表三中报告了两种方法在三个指标上多次实验的平均值和标准差。
表三
为了进一步验证本申请中构建多个视图、融合视图表征的效果,我们分别采用特征相似度和选取不同元路径模式构建同构视图,并在有多个同构视图情况下使用融合视图表征机制进行实验。与不进行对比学习、只使用HGN编码原始异构图的方法相比,不同同构视图选择在5次实验下F1-score的增益的平均增益均值汇总在表四中。
表四
现有异构图对比学习方法构建的对比视图不能充分表达节点关系,或存在噪声。本申请充分挖掘异构图中节点的不同关系。构建多个对比视图充分表征各种关系,并使用视图表征融合机制以动态融合不同视图表征。
现有异构图对比学习方法,其对比损失与分类损失的目标函数不一致导致知识的负迁移,影响分类效果。在本申请构建对比样本对时考虑标签信息,动态扩充有标签集合,统一了对比学习与下游任务的优化目标,在节点分类准确度上达到了明显优于业界一流方案的效果。
与前述一种分类模型的训练方法的实施例基于相同的构思,本申请实施例中还提供了一种分类模型的训练装置700,该分类模型的训练装置700包括用以实现图2a-4所示的分类模型的训练方法中的各个步骤的单元或模块。
图7为本申请实施例提供的一种分类模型的训练装置的结构示意图。该装置应用于计算设备,如图7所示,该一种分类模型的训练装置700至少包括:
获取模块701,用于获取训练数据集,所述训练数据集包括异构图信息和标签样本集,所述异构图信息包括多种类型的节点,所述多种类型的节点包括目标类型的节点,所述标签样本集包括所述目标类型的节点中的部分节点,所述部分节点具有预先标注的标签;
对比视图构建模块702,用于基于所述异构图信息构建多个对比视图,所述对比视图包括异构视图和同构视图,所述异构视图包括所述多种类型的节点,所述同构视图包括多个所述目标类型的节点;
第一表征向量确定模块703,用于将所述目标类型的节点中各个节点的第一表征向量输入分类器,得到所述目标类型的节点中各个节点的预测类别的概率分布;
监督损失确定模块704,用于基于所述目标类型的节点中的部分节点的预测类别的概率分布和所述目标类型的节点中的部分节点的标签,得到监督损失;
第二表征向量确定模块705,用于将所述同构视图输入同构图神经网络,得到所述目标类型的节点中各个节点的第二表征向量;
对比损失确定模块706,用于基于正样本对的相似度和负样本对的相似度,确定对比损失;
其中,所述正样本对包括所述异构视图中的所述目标类型的第一节点和所述同构视图中的所述目标类型的第二节点,所述第一节点和第二节点具有相同的标签,所述负样本对包括所述目标类型的第一节点和所述同构视图中的所述目标类型的第三节点,所述第一节点和第三节点具有不同的标签,所述正样本对的相似度基于所述第一节点的第一表征向量和第二节点的第二表征向量的相似度得到,所述负样本对的相似度基于所述第一节点的第一表征向量和第三节点的第二表征向量的相似度得到;
调参模块707,用于以最小化所述监督损失和所述对比损失为目标,调整所述异构图神经网络和分类器的参数,以得到训练后的所述分类模型。
在一个可能的实现中,所述同构视图包括M个同构子视图,所述同构图神经网络包括M个同构图神经网络,所述M为大于1的正整数;
所述第二表征向量确定模块705具体用于:
将所述M个同构子视图输入M个同构图神经网络,得到所述M个同构子视图中各个同构子视图的视图表征;
将所述各个同构子视图的视图表征进行融合,得到所述同构视图的视图表征,所述同构视图的视图表征包括所述目标类型的节点中的各个节点的第二表征向量。
在另一个可能的实现中,所述将所述各个同构子视图的视图表征进行融合,得到所述同构视图的视图表征,包括:
对所述各个同构子视图的视图表征进行注意力运算,得到所述各个同构子视图的注意力分数;
基于所述各个同构子视图的注意力分数,得到所述各个同构子视图的权重;
基于所述各个同构子视图的权重,对所述各个同构子视图的表征进行加权聚合,得到所述同构视图的视图表征。
在另一个可能的实现中,所述对所述各个同构子视图的视图表征进行注意力运算,得到所述各个同构子视图的注意力分数,包括:
将所述各个同构子视图的视图表征通过映射矩阵,映射至公共特征空间,得到所述各个同构子视图的转换视图表征;
基于注意力机制对所述各个同构子视图的转换视图表征进行注意力运算,得到所述各个同构子视图的注意力分数。
在另一个可能的实现中,所述M个同构子视图包括第一同构子视图和第二同构子视图;
所述第一同构子视图基于所述目标类型的节点间的相似度构建;
所述第二同构子视图基于所述目标类型的节点对应的元路径构建。
在另一个可能的实现中,所述M个同构子视图基于所述目标类型的节点对应的多条不同元路径构建。
在另一个可能的实现中,还包括标签样本集扩充模块708,用于基于所述目标类型的节点中各个节点的预测类别的概率分布,得到所述目标类型的节点中各个节点的预测置信度;
将所述预测置信度达标的所述目标类型的节点,以及所述预测置信度达标的所述目标类型的节点对应的预测标签,补充至所述标签样本集。
在另一个可能的实现中,将预测置信度大于或等于阈值的所述目标类型的节点确定为所述预测置信度达标的所述目标类型的节点。
在另一个可能的实现中,所述阈值基于所述具有预先标注的标签的所述目标类型的节点对应的预测置信度均值得到;
所述预测置信度均值表征所述具有预先标注的标签的所述目标类型的节点中的各个节点对应的预测置信度的平均值。
在另一个可能的实现中,所述对比损失与所述正样本对的相似度正相关,与所述负样本对的相似度负相关。
根据本申请实施例的分类模型的训练装置700可对应于执行本申请实施例中描述的方法,并且一种分类模型的训练装置700中的各个模块的上述和其它操作和/或功能分别为了实现图2a-4中的各个方法的相应流程,为了简洁,在此不再赘述。
与前述一种分类方法的实施例基于相同的构思,本申请实施例中还提供了一种分类装置800,该分类装置800包括用以实现图5所示的分类方法中的各个步骤的单元或模块。
图8为本申请实施例提供的一种分类装置的结构示意图。该装置应用于计算设备,如图8所示,该一种分类装置800至少包括:
获取模块801,用于获取异构图信息,所述异构图信息包括多种类型的节点,所述多种类型的节点包括目标类型的节点;
推理模块802,用于将所述异构图信息输入分类模型,得到所述目标类型的节点中各个节点的预测类别;
其中,所述分类模型基于本申请实施例提供的分类模型的训练方法训练得到。
根据本申请实施例的分类装置800可对应于执行本申请实施例中描述的方法,并且一种分类装置800中的各个模块的上述和其它操作和/或功能分别为了实现图5中的各个方法的相应流程,为了简洁,在此不再赘述。
本申请实施例还提供一种计算设备,包括至少一个处理器、存储器和通信接口,所述处理器用于执行图2a-5所述的方法。
图9为本申请实施例提供的计算设备的结构示意图。
如图9所示,所述计算设备900包括至少一个处理器901、存储器902、图形处理器(graphics processing unit,GPU)903和通信接口904。其中,处理器901、存储器902、图形处理器903和通信接口904通信连接,可以通过有线(例如总线)的方式实现通信连接,也可以通过无线的方式实现通信连接。该通信接口904用于接收其他设备发送的数据(例如异构图信息);存储器902存储有计算机指令,处理器901执行该计算机指令,执行前述方法实施例中的分类模型的训练方法和/或分类方法。图形处理器903中部署有分类模型。
应理解,在本申请实施例中,该处理器901可以是中央处理单元CPU,该处理器901还可以是其他通用处理器、数字信号处理器(digital signal processor,DSP)、专用集成电路(application specific integrated circuit,ASIC)、现场可编程门阵列(fieldprogrammable gate array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者是任何常规的处理器等。
该存储器902可以包括只读存储器和随机存取存储器,并向处理器901提供指令和数据。存储器902还可以包括非易失性随机存取存储器。
该存储器902可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(read-only memory,ROM)、可编程只读存储器(programmable ROM,PROM)、可擦除可编程只读存储器(erasable PROM,EPROM)、电可擦除可编程只读存储器(electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(random access memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(static RAM,SRAM)、动态随机存取存储器(DRAM)、同步动态随机存取存储器(synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(double data date SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(synchlinkDRAM,SLDRAM)和直接内存总线随机存取存储器(direct rambus RAM,DR RAM)。
应理解,根据本申请实施例的计算设备900可以执行实现本申请实施例中图2a-5所示方法,该方法实现的详细描述参见上文,为了简洁,在此不再赘述。
本申请的实施例提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机指令在被处理器执行时,使得上文提及的分类模型的训练方法和/或分类方法被实现。
本申请的实施例提供了一种芯片,该芯片包括至少一个处理器和接口,所述至少一个处理器通过所述接口确定程序指令或者数据;该至少一个处理器用于执行所述程序指令,以实现上文提及的分类模型的训练方法和/或分类方法。
本申请的实施例提供了一种计算机程序或计算机程序产品,该计算机程序或计算机程序产品包括指令,当该指令执行时,令计算机执行上文提及的分类模型的训练方法和/或分类方法。
本领域普通技术人员应该还可以进一步意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执轨道,取决于技术方案的特定应用和设计约束条件。本领域普通技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
结合本文中所公开的实施例描述的方法或算法的步骤可以用硬件、处理器执轨道的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、硬盘、可移动磁盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质中。
以上所述的具体实施方式,对本申请的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本申请的具体实施方式而已,并不用于限定本申请的保护范围,凡在本申请的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (24)
1.一种分类模型的训练方法,其特征在于,所述分类模型包括异构图神经网络和分类器,所述分类模型用于预测异构图中同一类型节点的标签的概率分布,所述方法包括:
获取数据集,所述数据集包括异构图和所述异构图中部分节点的标签,所述部分节点的类型相同;
将所述异构图输入所述分类模型,以最小化监督损失函数和对比损失函数为目标,调整所述异构图神经网络的参数和所述分类器的参数,以得到训练后的所述分类模型;
其中,所述监督损失函数指示所述部分节点预测的标签的概率分布和所述部分节点的标签之间的关系;所述对比损失函数指示正样本对的相似度与负样本对的相似度之间的关系,所述正样本对包括第一节点和第二节点,所述负样本对包括所述第一节点和第三节点,所述第一节点是异构图中的节点,所述第二节点和所述第三节点是与所述异构图对应的所述同构图中的节点,所述同构图中的节点的类型与所述部分节点的类型相同,所述第一节点和所述第二节点具有相同的类型和相同的标签,所述第一节点和所述第三节点具有相同的类型和不同的标签。
2.根据权利要求1所述的方法,其特征在于,还包括,将所述同构图作为同构图神经网络的输入,输出所述同构图的节点中各个节点的第二表征向量;
所述正样本对的相似度基于所述第一节点的第一表征向量和第二节点的第二表征向量的相似度得到,所述负样本对的相似度基于所述第一节点的第一表征向量和第三节点的第二表征向量的相似度得到;其中,所述同一类型节点中各个节点的第一表征向量基于所述异构图作为所述异构图神经网络的输入,所述异构图神经网络输出得到。
3.根据权利要求2所述的方法,其特征在于,所述同构图包括M个同构子图,所述M为大于1的正整数;
所述将所述同构图作为同构图神经网络的输入,输出所述同构图的节点中各个节点的第二表征向量,包括:
将所述M个同构子图分别输入其各自对应的同构图神经网络,得到所述M个同构子图中各个同构子图的视图表征;
将所述各个同构子图的视图表征进行融合,得到所述同构图的视图表征,所述同构图的视图表征包括所述同构图的节点中的各个节点的第二表征向量。
4.根据权利要求3所述的方法,其特征在于,所述将各个同构子图的视图表征进行融合,得到所述同构图的视图表征,包括:
基于所述各个同构子图的权重,对所述各个同构子图的表征进行加权聚合,得到所述同构图的视图表征,所述各个同构子图的权重与各个同构子图的注意力分数有关。
5.根据权利要求4所述的方法,其特征在于,所述各个同构子图的注意力分数基于注意力机制对所述各个同构子图的转换视图表征进行注意力运算得到;
所述各个同构子图的转换视图表征为将所述各个同构子图的视图表征通过映射矩阵,映射至公共特征空间得到。
6.根据权利要求3-5任一项所述的方法,其特征在于,所述M个同构子图包括第一同构子图和第二同构子图;
所述第一同构子图基于所述同一类型节点的节点间的相似度构建;
所述第二同构子图基于所述同一类型节点对应的元路径构建。
7.根据权利要求3-5任一项所述的方法,其特征在于,所述M个同构子图基于所述同一类型节点对应的多条不同的元路径构建。
8.根据权利要求1-7任一项所述的方法,其特征在于,还包括:基于所述同一类型节点中各个节点的预测标签的概率分布,得到所述同一类型节点中各个节点的预测置信度;
将所述预测置信度达标的所述同一类型节点以所述预测标签标注标签。
9.根据权利要求8所述的方法,其特征在于,将预测置信度大于或等于阈值的所述同一类型节点确定为所述预测置信度达标的所述同一类型节点。
10.根据权利要求9所述的方法,其特征在于,所述阈值基于所述部分节点对应的预测置信度均值得到;
所述预测置信度均值表征所述部分节点中的各个节点对应的预测置信度的平均值。
11.一种分类方法,其特征在于,包括:
获取异构图,所述异构图包括多种类型的节点;
将所述异构图输入分类模型,得到同一类型节点中各个节点的预测标签;
其中,所述分类模型基于所述权利要求1-10任一项所述的方法训练得到。
12.一种分类模型的训练装置,其特征在于,所述分类模型包括异构图神经网络和分类器,所述分类模型用于预测异构图中同一类型节点的标签的概率分布,所述训练装置包括:
获取模块,用于获取数据集,所述数据集包括异构图和所述异构图中部分节点的标签,所述部分节点的类型相同;
调参模块,用于将所述异构图输入所述分类模型,以最小化监督损失函数和对比损失函数为目标,调整所述异构图神经网络的参数和所述分类器的参数,以得到训练后的所述分类模型;
其中,所述监督损失函数指示所述部分节点预测的标签的概率分布和所述部分节点的标签之间的关系;所述对比损失函数指示正样本对的相似度与负样本对的相似度之间的关系,所述正样本对包括第一节点和第二节点,所述负样本对包括所述第一节点和第三节点,所述第一节点是异构图中的节点,所述第二节点和所述第三节点是与所述异构图对应的所述同构图中的节点,所述同构图中的节点的类型与所述部分节点的类型相同,所述第一节点和所述第二节点具有相同的类型和相同的标签,所述第一节点和所述第三节点具有相同的类型和不同的标签。
13.根据权利要求12所述的装置,其特征在于,还包括第二表征向量确定模块,
用于将所述同构图作为同构图神经网络的输入,输出所述同构图的节点中各个节点的第二表征向量;
所述正样本对的相似度基于所述第一节点的第一表征向量和第二节点的第二表征向量的相似度得到,所述负样本对的相似度基于所述第一节点的第一表征向量和第三节点的第二表征向量的相似度得到;其中,所述同一类型节点中各个节点的第一表征向量基于所述异构图作为所述异构图神经网络的输入,所述异构图神经网络输出得到。
14.根据权利要求13所述的装置,其特征在于,所述同构图包括M个同构子图,所述M为大于1的正整数;
所述第二表征向量确定模块还用于:
将所述M个同构子图分别输入其各自对应的同构图神经网络,得到所述M个同构子图中各个同构子图的视图表征;
将所述各个同构子图的视图表征进行融合,得到所述同构图的视图表征,所述同构图的视图表征包括所述同构图的节点中的各个节点的第二表征向量。
15.根据权利要求14所述的装置,其特征在于,所述将各个同构子图的视图表征进行融合,得到所述同构图的视图表征,包括:
基于所述各个同构子图的权重,对所述各个同构子图的表征进行加权聚合,得到所述同构图的视图表征,所述各个同构子图的权重与各个同构子图的注意力分数有关。
16.根据权利要求15所述的装置,其特征在于,所述各个同构子图的注意力分数基于注意力机制对所述各个同构子图的转换视图表征进行注意力运算得到;
所述各个同构子图的转换视图表征为将所述各个同构子图的视图表征通过映射矩阵,映射至公共特征空间得到。
17.根据权利要求14-16任一项所述的装置,其特征在于,所述M个同构子图包括第一同构子图和第二同构子图;
所述第一同构子图基于所述同一类型节点的节点间的相似度构建;
所述第二同构子图基于所述同一类型节点对应的元路径构建。
18.根据权利要求14-16任一项所述的装置,其特征在于,所述M个同构子图基于所述同一类型节点对应的多条不同的元路径构建。
19.根据权利要求12-18任一项所述的装置,其特征在于,还包括标签样本集扩充模块,还包括:基于所述同一类型节点中各个节点的预测标签的概率分布,得到所述同一类型节点中各个节点的预测置信度;
将所述预测置信度达标的所述同一类型节点以所述预测标签标注标签。
20.根据权利要求19所述的装置,其特征在于,将预测置信度大于或等于阈值的所述同一类型节点确定为所述预测置信度达标的所述同一类型节点。
21.根据权利要求20所述的装置,其特征在于,所述阈值基于所述部分节点对应的预测置信度均值得到;
所述预测置信度均值表征所述部分节点中的各个节点对应的预测置信度的平均值。
22.一种分类装置,其特征在于,包括:
获取模块,用于获取异构图,所述异构图包括多种类型的节点;
推理模块,用于将所述异构图输入分类模型,得到同一类型节点中各个节点的预测标签;
其中,所述分类模型基于所述权利要求1-10任一项所述的方法训练得到。
23.一种计算设备,包括存储器和处理器,其特征在于,所述存储器中存储有指令,当所述指令被处理器执行时,使得如权利要求1-10任一项所述的方法,和/或如权利要求11所述的方法被实现。
24.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,当所述计算机程序在被处理器执行时,使得如权利要求1-10任一项所述的方法,和/或如权利要求11所述的方法被实现。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210462062.9A CN117033992A (zh) | 2022-04-28 | 2022-04-28 | 一种分类模型的训练方法及装置 |
PCT/CN2023/089797 WO2023207790A1 (zh) | 2022-04-28 | 2023-04-21 | 一种分类模型的训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210462062.9A CN117033992A (zh) | 2022-04-28 | 2022-04-28 | 一种分类模型的训练方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117033992A true CN117033992A (zh) | 2023-11-10 |
Family
ID=88517568
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210462062.9A Pending CN117033992A (zh) | 2022-04-28 | 2022-04-28 | 一种分类模型的训练方法及装置 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN117033992A (zh) |
WO (1) | WO2023207790A1 (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113849645A (zh) * | 2021-09-28 | 2021-12-28 | 平安科技(深圳)有限公司 | 邮件分类模型训练方法、装置、设备及存储介质 |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117807275A (zh) * | 2023-12-29 | 2024-04-02 | 江南大学 | 基于关系挖掘的异构图嵌入方法及系统 |
CN117976139A (zh) * | 2024-03-29 | 2024-05-03 | 武汉纺织大学 | 一种基于纠偏机制和对比学习的药物重定位方法及系统 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111309983B (zh) * | 2020-03-10 | 2021-09-21 | 支付宝(杭州)信息技术有限公司 | 基于异构图进行业务处理的方法及装置 |
CN111400560A (zh) * | 2020-03-10 | 2020-07-10 | 支付宝(杭州)信息技术有限公司 | 一种基于异构图神经网络模型进行预测的方法和系统 |
CN112966763B (zh) * | 2021-03-17 | 2023-12-26 | 北京邮电大学 | 一种分类模型的训练方法、装置、电子设备及存储介质 |
CN113326884B (zh) * | 2021-06-11 | 2023-06-16 | 之江实验室 | 大规模异构图节点表示的高效学习方法及装置 |
CN114239711A (zh) * | 2021-12-06 | 2022-03-25 | 中国人民解放军国防科技大学 | 基于异构信息网络少样本学习的节点分类方法 |
-
2022
- 2022-04-28 CN CN202210462062.9A patent/CN117033992A/zh active Pending
-
2023
- 2023-04-21 WO PCT/CN2023/089797 patent/WO2023207790A1/zh unknown
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113849645A (zh) * | 2021-09-28 | 2021-12-28 | 平安科技(深圳)有限公司 | 邮件分类模型训练方法、装置、设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
WO2023207790A1 (zh) | 2023-11-02 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN117033992A (zh) | 一种分类模型的训练方法及装置 | |
Li et al. | A knowledge-driven anomaly detection framework for social production system | |
CN110852447A (zh) | 元学习方法和装置、初始化方法、计算设备和存储介质 | |
Han et al. | Sentiment analysis via semi-supervised learning: a model based on dynamic threshold and multi-classifiers | |
Zhuge et al. | Joint consensus and diversity for multi-view semi-supervised classification | |
CN115048586B (zh) | 一种融合多特征的新闻推荐方法及系统 | |
Ma et al. | Class-imbalanced learning on graphs: A survey | |
CN112257959A (zh) | 用户风险预测方法、装置、电子设备及存储介质 | |
Yang et al. | Out-of-distribution detection with semantic mismatch under masking | |
Li et al. | Transductive distribution calibration for few-shot learning | |
Zhu et al. | CCBLA: a lightweight phishing detection model based on CNN, BiLSTM, and attention mechanism | |
Wang et al. | Few-shot node classification with extremely weak supervision | |
Shen et al. | UniSKGRep: A unified representation learning framework of social network and knowledge graph | |
Bai et al. | Benchmarking tropical cyclone rapid intensification with satellite images and attention-based deep models | |
Qu et al. | Improving the reliability for confidence estimation | |
Du et al. | Structure tuning method on deep convolutional generative adversarial network with nondominated sorting genetic algorithm II | |
CN117349494A (zh) | 空间图卷积神经网络的图分类方法、系统、介质及设备 | |
Eom et al. | Multi-task learning for spatial events prediction from social data | |
Fu et al. | Multi-label learning with kernel local label information | |
Hua et al. | Robust and sparse label propagation for graph-based semi-supervised classification | |
Wu et al. | Learning deep networks with crowdsourcing for relevance evaluation | |
Zhu et al. | A hybrid model for nonlinear regression with missing data using quasilinear kernel | |
Tsai et al. | Predicting online news popularity based on machine learning | |
Zhou et al. | Robust graph structure learning for multimedia data analysis | |
CN115878882A (zh) | 用户兴趣的分层表示学习 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |