CN117473315A - 一种基于多层感知机的图分类模型构建方法和图分类方法 - Google Patents
一种基于多层感知机的图分类模型构建方法和图分类方法 Download PDFInfo
- Publication number
- CN117473315A CN117473315A CN202311423387.7A CN202311423387A CN117473315A CN 117473315 A CN117473315 A CN 117473315A CN 202311423387 A CN202311423387 A CN 202311423387A CN 117473315 A CN117473315 A CN 117473315A
- Authority
- CN
- China
- Prior art keywords
- graph
- classification
- model
- data
- layer perceptron
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 92
- 238000013145 classification model Methods 0.000 title claims abstract description 55
- 238000010276 construction Methods 0.000 title claims abstract description 12
- 238000012549 training Methods 0.000 claims abstract description 81
- 238000013528 artificial neural network Methods 0.000 claims abstract description 60
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 24
- 238000004891 communication Methods 0.000 claims description 19
- 230000006870 function Effects 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 7
- 238000003860 storage Methods 0.000 claims description 7
- 230000008569 process Effects 0.000 abstract description 14
- 230000008901 benefit Effects 0.000 abstract description 2
- 238000012821 model calculation Methods 0.000 abstract 1
- 238000012545 processing Methods 0.000 description 14
- 238000004422 calculation algorithm Methods 0.000 description 6
- 238000004364 calculation method Methods 0.000 description 6
- 238000002474 experimental method Methods 0.000 description 5
- 230000006399 behavior Effects 0.000 description 4
- 230000003993 interaction Effects 0.000 description 4
- 238000004519 manufacturing process Methods 0.000 description 4
- 238000005457 optimization Methods 0.000 description 4
- 238000007418 data mining Methods 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 239000000284 extract Substances 0.000 description 3
- 239000011159 matrix material Substances 0.000 description 3
- 238000011176 pooling Methods 0.000 description 3
- 238000011160 research Methods 0.000 description 3
- 238000012546 transfer Methods 0.000 description 3
- 239000013598 vector Substances 0.000 description 3
- 230000002776 aggregation Effects 0.000 description 2
- 238000004220 aggregation Methods 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 2
- 238000004821 distillation Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000008451 emotion Effects 0.000 description 2
- 230000002996 emotional effect Effects 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 230000010365 information processing Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 230000001537 neural effect Effects 0.000 description 2
- 230000002093 peripheral effect Effects 0.000 description 2
- 235000002020 sage Nutrition 0.000 description 2
- 230000003595 spectral effect Effects 0.000 description 2
- NPPQSCRMBWNHMW-UHFFFAOYSA-N Meprobamate Chemical compound NC(=O)OCC(C)(CCC)COC(N)=O NPPQSCRMBWNHMW-UHFFFAOYSA-N 0.000 description 1
- 241000233805 Phoenix Species 0.000 description 1
- VREFGVBLTWBCJP-UHFFFAOYSA-N alprazolam Chemical compound C12=CC(Cl)=CC=C2N2C(C)=NN=C2CN=C1C1=CC=CC=C1 VREFGVBLTWBCJP-UHFFFAOYSA-N 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 238000010420 art technique Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000007635 classification algorithm Methods 0.000 description 1
- 230000002860 competitive effect Effects 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 238000000802 evaporation-induced self-assembly Methods 0.000 description 1
- PCHJSUWPFVWCPO-UHFFFAOYSA-N gold Chemical compound [Au] PCHJSUWPFVWCPO-UHFFFAOYSA-N 0.000 description 1
- 239000010931 gold Substances 0.000 description 1
- 229910052737 gold Inorganic materials 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000007935 neutral effect Effects 0.000 description 1
- 230000008520 organization Effects 0.000 description 1
- 230000011273 social behavior Effects 0.000 description 1
- 239000000126 substance Substances 0.000 description 1
- 230000008685 targeting Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0495—Quantised networks; Sparse networks; Compressed networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于多层感知机的图分类模型构建方法和图分类方法,前者包括:针对所需应用场景的分类任务获取若干图数据构成训练集;训练集至少部分图数据有类别标签;利用训练集对选定的图神经网络训练,将训练完成的图神经网络作为教师网络并保存图神经网络对应的分类结果;将选定的多层感知机作为学生模型,基于训练集、教师网络以及对应的分类结果,采用知识蒸馏的方法对学生模型训练,得到训练完成的学生模型作为图分类模型用于对所需应用场景的其余图数据分类。本发明的图分类模型结合多层感知机和图神经网络各自优点,在推理过程中无图的依赖性,保证较高精度的同时能大幅降低模型计算复杂度,提高推理速度,可用于时间受限的工程部署。
Description
技术领域
本发明属于人工智能领域,具体涉及一种基于多层感知机的图分类模型构建方法和图分类方法。
背景技术
图数据(Graph data)是现实世界中广泛存在的一种数据结构,它由节点和边构成,用于描述节点之间的关系或连接。例如,社交网络中的用户可以看作是节点,而他们之间的互动可以看作是边。图数据结构的复杂性和不规则性使得它蕴含了丰富的信息,但也给算法设计和实现带来了挑战。
近年来,对图数据的研究成为了热点方向。目前,针对图分类方法的研究主要有基于相似度计算的方法和基于图神经网络(Graph Neural Networks,GNNs)的方法。基于相似度计算的方法通过计算成对图的相似度对图进行分类,包括图核方法和图匹配方法。其中,图核方法主要通过图核的定义来计算图的相似度,它们共同的思想是将图分解为某种子结构,通过对比不同图上的子结构来计算图的相似度进而进行图分类。基于图匹配的方法则是考虑通过一些跨图的因素来计算图之间的相似度分数来进行图分类。图神经网络是一种针对图数据的深度学习算法,它可以通过学习节点和边的特征表示来进行分类任务;具体的,基于图神经网络的方法使用深度学习来建模图数据,利用图的结构信息和节点特征信息对图的特征进行提取并汇总得到整个图的表示用于分类。
目前,图神经网络在社交网络等不同领域都有广泛的应用。例如,在社交网络中,可以基于用户的行为、兴趣和特征进行分类,以帮助社交网络平台更好地了解用户需求,提供个性化的内容和服务。
然而,由于现实世界中的图数据规模庞大,图神经网络算法难以运行。这是因为图数据具有稀疏性和不规则性,导致算法的计算复杂度很高。因此,在大规模工业应用中,部署仍以多层感知机(Multilayer Perceptron,MLP)为主。多层感知机是一种经典的神经网络模型,它可以对输入数据进行高维特征提取和分类。多层感知机的优点是计算效率高,但缺点是无法利用图数据的结构信息,精度不如图神经网络。因此,亟需一种复杂度低且快速的图分类方法。
发明内容
为了解决现有技术中存在的上述问题,本发明提供了一种基于多层感知机的图分类模型构建方法、一种基于多层感知机的图分类方法、装置、电子设备和存储介质。本发明要解决的技术问题通过以下技术方案实现:
第一方面,本发明实施例提供了一种基于多层感知机的图分类模型构建方法,所述方法包括:
针对所需应用场景的分类任务,获取若干图数据构成训练集;其中,所述训练集中至少部分图数据带有类别标签;
利用所述训练集对选定的图神经网络进行训练,将训练完成的图神经网络作为教师网络,并保存所述图神经网络对应的分类结果;
将选定的多层感知机作为学生模型,基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,得到训练完成的学生模型作为图分类模型,用于对所需应用场景的其余图数据进行分类。
在本发明的一个实施例中,所述所需应用场景,包括:
社交场景。
在本发明的一个实施例中,所述选定的图神经网络,包括:
GraphSAGE和GIN。
在本发明的一个实施例中,所述选定的多层感知机,包括:
多层的全连接网络。
在本发明的一个实施例中,所述基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,包括:
由所述训练集中图数据的特征数据构成输入数据集;
基于所述输入数据集、所述输入数据集中数据带有的类别标签、所述图神经网络对应的分类结果以及预设的损失函数,采用知识蒸馏的方法对所述学生模型进行训练直至所述学生模型收敛,得到训练完成的学生模型。
在本发明的一个实施例中,所述预设的损失函数为:
其中,v表示图数据;V表示所述训练集;VL表示所述训练集中有类别标签的图数据构成的集合;表示图数据v对应的学生模型的分类结果;yv表示图数据v的类别标签;zv表示图数据v对应的教师模型的分类结果;Llabel表示真实的类别标签与学生模型的分类结果之间的损失;Lteacher表示学生模型的分类结果和教师模型的分类结果之间的损失;λ表示权重参数。
第二方面,本发明实施例提供了一种基于多层感知机的图分类方法,所述方法包括:
获取社交场景中待分类的目标图数据;
将所述目标图数据中的特征数据输入至预先训练完成的图分类模型中,得到对应的分类结果;其中,所述图分类模型第一方面所述的基于多层感知机的图分类模型构建方法得到。
第三方面,本发明实施例提供了一种基于多层感知机的图分类装置,所述装置包括:
图数据获取模块,用于获取社交场景中待分类的目标图数据;
分类模块,用于将所述目标图数据中的特征数据输入至预先训练完成的图分类模型中,得到对应的分类结果;其中,所述图分类模型根据第一方面所述的基于多层感知机的图分类模型构建方法得到。
第四方面,本发明实施例提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,所述处理器、所述通信接口、所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行所述存储器上所存放的程序时,实现本发明实施例所提供的基于多层感知机的图分类方法的步骤。
第五方面,本发明实施例提供了一种计算机可读存储介质,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现本发明实施例所提供的基于多层感知机的图分类方法的步骤。
本发明的有益效果:图神经网络在处理复杂的图数据和结构时表现出色,但图结构的不规则性和动态性使得图神经网络在大规模图或动态图上的计算和优化更加复杂;而多层感知机模型在处理传统的欧几里得数据时更为高效和易用,但在处理复杂的图形关系和拓扑时能力有限,本发明所提供方案中的图分类模型,结合了多层感知机和图神经网络各自的优点,采用知识蒸馏的方法将图神经网络从图数据中提取到的拓扑信息、特征信息及其他有用信息传递给多层感知机,由于经过知识蒸馏方法训练得到的多层感知机用于分类,在推理过程中没有了图的依赖性,使其能够保持有竞争力的准确度,同时减少了图神经网络推理过程中的数据依赖问题,在保证较高精度的同时,大幅降低了模型的计算复杂度,提高了推理速度,因此可以像传统的神经网络一样进行预测和泛化。此外,通过离线的知识蒸馏,学生模型的参数得到了优化,可以像作为教师模型的图神经网络一样具有更高的精度,同时具有更快的推理速度,更容易部署到生产环境中,可用于时间受限的工程部署。
采用本发明实施例提供的图分类模型对社交场景中待分类的图数据进行分类,能够同时保证分类精度和分类速度。
附图说明
图1为本发明实施例所提供的一种基于多层感知机的图分类模型构建方法的流程示意图;
图2为图数据的数据结构示例图;
图3为本发明实施例所提供的一种基于多层感知机的图分类方法的流程示意图;
图4为本发明实施例所提供的一种基于多层感知机的图分类装置的结构示意图;
图5为本发明实施例所提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
目前,针对图分类方法的研究中,Borgwardt等人在文献“Borgwardt KM,KriegelHP.Shortest-path kernels on graphs.In:Proc.of the Fifth IEEE Int’l Conf.onData Mining(ICDM’05).Houston:IEEE,2005.8.[doi:10.1109/ICDM.2005.132]“中提出用弗洛伊德算法(Floyd-Warshall)得到最短路径图,相同长度的最短路径越多,图的相似度越高。Li等人在文献”Li YJ,Gu CJ,Dullien T,Vinyals O,Kohli P.Graph matchingnetworks for learning the similarity of graph structured objects.In:Proc.ofthe 36th Int’l Conf.on Machine Learning.California:PMLR,2019.3835-3845.“中提出了图匹配网络模型(GMN),该模型不仅考虑每个图内节点信息,也考虑跨图节点的信息,可以保证当两个图的匹配度较高时,学得的表示也更相似,相似度分数更高。Bai等人在文献”Bai YS,Ding H,Bian S,Chen T,Sun YZ,Wang W.Simgnn:Aneural network approachto fast graph similarity computation.In:Proc.of the 12th ACM Int’l Conf.onWeb Search and Data Mining.Phoenix:ACM,2019.384-392.[doi:10.1145/3289600.3290967]“中提出了快速计算图相似度的模型SimGNN。该模型不仅计算了跨图的节点相似度向量,而且计算了两个图表示间的相似度向量,进而利用这两个相似度向量计算得到两个图的相似度分数。Gilmer等人在文献”Gilmer J,Schoenholz SS,Riley PF,Vinyals O,Dahl GE.Neural message passing for quantum chemistry.In:Proc.of the34th Int’l Conf.on Machine Learning.Sydney:JMLR,2017.1263-1272.“中提出了信息传递式(MPNN)的通用框架,将卷积过程形式化为信息传递和节点信息更新两个函数。Xu等人在文献”Xu K,Hu W,Leskovec J,Jegelka S.How powerful are graph neuralnetworks?.In:Proc.of the Int’l Conf.on Learning Representations(ICLR).NewOrleans:OpenReview.net,2019.“中指出,基于1-WL的任何信息传递式的图卷积神经网络,其表达能力都不会强于1-WL,而1-WL本身的表达能力有限,对于许多非同构图并不能准确地分类。Ma等人在文献”Ma Y,Wang SH,Aggarwal CC,Tang JL.Graph convolutionalnetworks with eigenpooling.In:Proc.of the 25th ACM SIGKDDInt’lConf.onKnowledge Discovery&Data Mining.Singapore:ACM,2019.723-731.[doi:10.1145/3292500.3330982]“中提出谱池化(EigenPooling)模型,利用谱聚类得到硬分配矩阵,并根据该矩阵进行子图划分。Ying等人在文献”Ying R,You JX,Morris C,Ren X,Hamilton WL,Leskovec J.Hierarchical graph representation learning with differentiablepooling.In:Proc.of the 32nd Int’l Conf.on Neural Information ProcessingSystems.Montréal:Curran Associates Inc.,2018.4805-4815.中提出的可微分池化(DiffPool)模型通过参数学习的方式得到软聚类分配矩阵来进行图分类。
由于图数据的规模往往非常庞大,包含的节点和边数量都非常多,因此处理大规模图数据是一个非常棘手的问题。目前的图分类技术对于大规模图数据的处理能力仍然受到限制,现有技术通常需要进行大量的计算操作,例如相似度计算、节点特征提取、图卷积、图注意力等。这些计算操作的复杂度往往较高,导致图分类技术的训练和推理速度较慢,这对于延迟受限的应用来说是无法忍受的。
而且,目前在大规模工业应用中,部署仍以多层感知机为主。多层感知机虽然计算效率高,但无法利用图数据的结构信息,精度不如图神经网络。因此,若能够提高多层感知机的精度使其能够进行图数据分类将具有重要的意义。在此基础上,本发明提供了一种基于多层感知机的图分类模型构建方法,能够用于所需应用场景的分类任务构建并训练得到图分类模型。并且针对社交这一特定的应用场景,提出了一种基于多层感知机的图分类方法、装置、电子设备和存储介质。以下对本发明实施例方案进行具体说明。
第一方面,本发明实施例提供了一种基于多层感知机的图分类模型构建方法,如图1所示,该方法可以包括如下步骤:
S1,针对所需应用场景的分类任务,获取若干图数据构成训练集;
其中,所述训练集中至少部分图数据带有类别标签;
图数据是一种表示对象之间关系的数据结构,它由节点(Nodes)和边(Edges)组成,节点表示对象,边表示对象之间的关系,如图2所示。节点可以表示各种实体,如人、物体、事件等,而边则表示节点之间的连接或关联。由于应用场景和需求的不同,图数据可以用于描述各种复杂的关系网络。本发明实施例中,应用场景可以包括社交场景,对应的关系网络即为社交网络,相应的,图数据的分类任务可以是:
社区分类:图分类可以用于识别社交网络中的社区或群体。社区分类旨在将网络中相互连接的节点划分为具有相似属性或共同兴趣的群体。通过图分类算法,可以将网络中的节点划分为不同的社区,从而帮助人们理解社交网络的组织结构和用户之间的关系。
用户分类:图分类可以用于对社交网络中的用户进行分类。通过分析用户在社交网络中的行为、关系和兴趣等信息,可以将用户分为不同的类别或群体。这对于个性化推荐、社交广告定向和用户关系分析等任务非常有用。
情感分类:图分类可以用于对社交网络中的文本或用户发表的内容进行情感分类。通过分析用户在社交网络中发布的帖子、评论或推文等文本信息,可以判断其情感倾向,例如正面、负面或中性。这对于舆情分析、情感监测和品牌声誉管理等方面具有重要意义。
当然,本发明实施例图数据可以描述的关系网络不限于社交网络,还可以包括交通网络、知识图谱、交通网络、分子结构、生物化学网络等。图数据能够捕捉对象之间的相互作用和依赖关系,提供了一种丰富的数据表示形式。在图数据中,节点和边的属性可以包含各种信息。节点属性描述了节点的特征,如人的年龄、物体的颜色等。边属性描述了节点之间的关系特征,例如社交网络中的好友关系、分子结构中的化学键等。
本发明实施例针对所需应用场景的分类任务,需要针对性地获取若干图数据,并可以获取每个图数据的类别标签,即图数据真实的类别标签,连同获取的无类别标签的图数据一起构成训练集。
S2,利用所述训练集对选定的图神经网络进行训练,将训练完成的图神经网络作为教师网络,并保存所述图神经网络对应的分类结果;
本发明实施例中,根据所述训练集中数据的类型、规模和特征等,选择合适的图神经网络和多层感知机。
可选的一种实施方式中,所述选定的图神经网络,可以包括:
GraphSAGE和GIN。
本领域技术人员可以理解的是,GraphSAGE(SAmple and aggreGatE)和GIN(GraphIsomorphism Network,图同构网络)是两种现有的图神经网络。当然,本发明实施例的图神经网络不限于此。
S2中,利用所述训练集对选定的图神经网络进行训练,可以采用现有的图神经网络训练方法实现,在此不做详细说明。训练结束后,可以得到训练完成的、针对所需应用场景分类任务下的图数据,具有良好分类性能的图神经网络,训练完成的图神经网络能够通过提取出数据中的有用特征,包括图的拓扑结构信息和特征信息等实现图数据准确分类。
并且,S2中可以保留该图神经网络对训练所述训练集得到的分类结果,包括所述训练集中每个图数据属于各个类别的置信度、图神经网络的中间特征,以及经过训练得到的参数等,以供后续的知识蒸馏过程使用。
S3,将选定的多层感知机作为学生模型,基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,得到训练完成的学生模型作为图分类模型,用于对所需应用场景的其余图数据进行分类。
可选的,本发明实施例中,所述选定的多层感知机,可以包括:
多层的全连接网络。
当然,本发明实施例选定的多层感知机不限于此,比如也可以是相关的轻量化神经网络等,在此不做限制。
本领域技术人员可以理解的是,“教师模型”指能够为待训练的模型提供指导信息的完备模型,通常为结构复杂度高且性能较好的模型;与之对应的为“学生模型”,通常为复杂度低,但性能相对较差的模型,本发明实施例中将训练完成的图神经网络作为教师模型,将选定的多层感知机作为学生模型。“知识蒸馏”指利用“教师模型”对复杂数据的处理与特征提取能力,将“教师模型”的关键信息“蒸馏”到更小、更高效的“学生模型”中,以对“学生模型”的训练与学习进行指导,包括其预测和内部特征表示,从而达到与更大模型类似的性能,但以更少的计算资源和更快的处理速度为代价。
在本发明实施例中,通过图神经网络对多层感知机的知识蒸馏过程,使得小型的多层感知机学习到大型图神经网络中的知识,提高多层感知机对图数据的表示能力与分类性能。经过蒸馏训练,学生模型自己从输入的数据集中学习特征相关的知识,同时教师模型会将其学习到的图的拓扑结构信息、特征信息及其他有用信息“传递”给学生模型。
具体的,S3中,所述基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,可以包括:
1)由所述训练集中图数据的特征数据构成输入数据集;
本领域技术人员可以理解的是,图数据可以包括结构数据和特征数据,特征数据表示节点对应的特征。
本发明实施例针对所述学生模型的数据仍是来源于训练图神经网络所采用的训练集,但仅仅采用结构数据。
2)基于所述输入数据集、所述输入数据集中数据带有的类别标签、所述图神经网络对应的分类结果以及预设的损失函数,采用知识蒸馏的方法对所述学生模型进行训练直至所述学生模型收敛,得到训练完成的学生模型。
该步骤的训练过程可以参见现有的知识蒸馏训练过程理解,针对该步骤,是将同一个训练数据在传递“教师模型”节点之间的关系和相互作用信息作为知识传递给“学生模型”,由于教师模型能够提取出拓扑信息、特征信息等有用信息,而学生模型的输入只有节点的特征信息,不包含图的拓扑结构信息,因此,学生模型没有了图的依赖性,能够不必利用图数据的结构信息,但仍旧能够利用作为教师模型的图神经网络传递的有用信息,因此能够提高多层感知机的精度,使其具有更快的推理速度。
其中,可以理解的是,训练过程需要利用损失函数完成,本发明实施例中,所述预设的损失函数为:
其中,v表示图数据;V表示所述训练集;VL表示所述训练集中有类别标签的图数据构成的集合;表示图数据v对应的学生模型的分类结果;yv表示图数据v的类别标签;zv表示图数据v对应的教师模型的分类结果;Llabel表示真实的类别标签与学生模型的分类结果之间的损失,可以采用交叉熵等常见的损失函数,用来保证学生模型的分类结果与真实类别标签的一致性;Lteacher表示学生模型的分类结果和教师模型的分类结果之间的损失,可以采用KL散度等方法,用来保证学生模型的分类结果与图神经网络提取的特征之间的一致性;λ表示权重参数,可以用来调节两部分损失的比例。
在上述训练过程中,通过反向传播算法和预设的优化方法,比如随机梯度下降法等,更新所述学生模型的参数,通过持续优化,训练所述学生模型,直至所述学生模型收敛,则完成学生模型的训练。
关于基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练的具体过程,请参见相关技术,在此不做详细说明。
通过上述步骤S1~S3,可以得到训练完成的学生模型作为图分类模型,之后,可选的,在此之后,所述基于多层感知机的图分类模型构建方法还可以包括:
将所述图分类模型部署至所需应用场景中的计算设备中。
其中,所述计算设备包括计算机、服务器等具有计算功能的设备。
可选的,可以将所述图分类模型部署至实际生产环境中的特定设备中,所述特定设备可以兼具图数据采集功能和分类计算功能,比如,所述特定设备可以为路由器、传感器等嵌入式设备,从而为实际生产环境中的图数据,输出分类结果。
本发明实施例中针对某一应用场景,构建得到图分类模型后,若在实际的图数据分类过程中,输入的图数据与该应用场景原先的图数据相比差异较大,则将其视为新增数据。在新增数据占当前输入图数据的比例超过预设比例(如20%等)时,则目前的图分类模型无法保证对新增数据的良好分类结果,其精度可能会下降,则需要利用新增数据对所述分类模型重新训练,即利用包含新增数据的所有图数据再次执行S1~S3步骤,得到新的分类模型。
本发明实施例所提供的基于多层感知机的图分类模型构建方法中,首先针对所需应用场景的分类任务,获取带有类别标签的若干图数据构成训练集;然后利用所述训练集对选定的图神经网络进行训练,将训练完成的图神经网络作为教师网络,并保存所述图神经网络对应的分类结果;最后将选定的多层感知机作为学生模型,基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,得到训练完成的学生模型作为图分类模型,用于对所需应用场景的其余图数据进行分类。图神经网络在处理复杂的图数据和结构时表现出色,但图结构的不规则性和动态性使得图神经网络在大规模图或动态图上的计算和优化更加复杂;而多层感知机模型在处理传统的欧几里得数据时更为高效和易用,但在处理复杂的图形关系和拓扑时能力有限,本发明方案结合了多层感知机和图神经网络各自的优点,采用知识蒸馏的方法将图神经网络从图数据中提取到的拓扑信息、特征信息及其他有用信息传递给多层感知机,由于经过知识蒸馏方法训练得到的多层感知机,在推理过程中没有了图的依赖性,使其能够保持有竞争力的准确度,同时减少了图神经网络推理过程中的数据依赖问题,在保证较高精度的同时,大幅降低了模型的计算复杂度,提高了推理速度,因此可以像传统的神经网络一样进行预测和泛化。此外,通过离线的知识蒸馏,学生模型的参数得到了优化,可以像作为教师模型的图神经网络一样具有更高的精度,同时具有更快的推理速度,更容易部署到生产环境中,可用于时间受限的工程部署。
为了便于理解本发明方法的构思和实现的效果,以下通过将其与相关现有技术进行对比来简要说明。
现有技术中,采用类似方法的研究大都集中在节点分类任务上。节点分类的重点是对图中的每个节点进行分类,主要关注如何对节点的特征和邻居节点的信息进行聚合和传递,以生成节点级别的表示。而图分类则需要对整个图进行分类,这就需要考虑如何对整个图的结构和特征进行汇聚和整合,以生成图级别的表示。
具体的,对于本发明实施例中的教师模型即图神经网络而言,节点分类和图分类这两类任务的数据都可以直接送入到网络模型中进行训练。而对于本发明实施例中的学生模型即多层感知机来说,由于其输入数据只有特征数据,且输入输出的维度固定,所以这两类任务在学生模型上的处理有所不同:在节点分类任务中,每个节点都有独立的标签,因此可以将输入数据以节点为粒度进行划分,并直接将数据送入学生模型进行分批训练;在图分类任务中,输入数据以图为粒度进行划分,而不同图之间的大小可能不同,这导致无法确定模型的参数,所以不能直接将数据送入学生模型进行分批训练。此外,在节点分类任务中,可以直接使用节点的标签信息作为教师模型和学生模型之间的目标,进行知识蒸馏的训练。但在图分类任务中,没有直接的图级标签可用于蒸馏目标的定义。
为了克服上述困难,本发明在学生模型的训练过程中引入了聚合操作,如求和、取均值等方法,通过这种方法来捕获图的全局信息,生成图级别的表示,通过设计的损失函数,将聚合后的图表示数据经过反向传播操作后训练模型。这样,经过聚合操作后,可以将整个图类比为一个节点,使得其后续的处理与节点分类任务类似。
以下通过实验,对本发明实施例构建的图分类模型的效果进行说明。
1.实验条件
本发明是在中央处理器为Intel(R)Xeon(R)Gold 6226R CPU@2.90GHz、NVIDIAGeForce RTX 3090、Ubuntu 20.04.1操作系统上,使用深度学习框架PyTorch进行试验。实验所用数据来自TUDatasets数据集,该数据集是一个广泛使用的图数据集合,包含多个不同的图分类和图回归任务。
实验中对比的方法如下:
一种是基于注意力机制的图神经网络,实验中记录为GAT,参考文献为P,Cucurull G,Casanova A,et al.Graph attention networks[J].arXiv preprintarXiv:1710.10903,2017。
另一种是可对未知节点起到泛化作用的图神经网络,实验中记为SAGE,参考文献为Hamilton W,Ying Z,Leskovec J.Inductive representation learning on largegraphs[J].Advances in neural information processing systems,2017,30.
其中MLP为未经过知识蒸馏的多层感知机,而本发明则为经过知识蒸馏的多层感知机。
2.实验内容
根据前文所述内容得到图分类模型,计算图分类任务的准确率和推理时间,并与GAT和SAGE方法的准确率和推理时间进行比较,结果如表1所示。
表1实验结果对比
从表1可见,通过知识蒸馏的方式来对多层感知机进行训练,可以使得学生模型的准确度与教师模型相当,同时推理时间减少,验证了本发明的有效性。
综上,本发明通过知识蒸馏的方法使多层感知机的精度提高,而推理时间相比图神经网络大幅降低,能够提高图分类任务的推理速度,能够解决由于工业规模的图太大导致图神经网络无法部署在延迟受限应用中的问题,可以更好地适应实际的工业应用场景。
第二方面,本发明实施例提供了一种基于多层感知机的图分类方法,如图3所示,该方法可以包括如下步骤:
S01,获取社交网络中待分类的目标图数据;
针对目标图数据的结构数据,图的结构数据表示了社交网络中的节点和它们之间的连接关系。节点可以表示用户、内容或事件,边表示节点之间的关系,例如好友关系、关注关系、转发关系等。
针对目标图数据的特征数据,节点特征可以包括用户的个人属性,如性别、年龄、地区等、用户的社交行为,如关注数、粉丝数等,或内容的属性,如文本特征、图片特征等。边特征可以包括社交关系的强度、互动频率等。
针对目标图数据的分类任务和分类结果中的类别取决于具体的应用场景和需求。比如分类任务可以包括但不限于以下所示:
1、社区分类:将社交网络中的用户或节点划分为不同的社区或群体。类别可以是不同的兴趣群体、社交圈子或用户类型。
2、用户行为分类:对社交网络中的用户进行行为分类,例如活跃用户、沉默用户、潜在用户等。
3、内容分类:将社交网络中的内容(如帖子、评论、推文等)进行分类,例如正面内容、负面内容、广告内容、媒体报道等。
S02,将所述目标图数据中的特征数据输入至预先训练完成的图分类模型中,得到对应的分类结果;
其中,所述图分类模型根据第一方面提供的基于多层感知机的图分类模型构建方法得到。关于图分类模型的相关内容请参见前文描述,在此不做赘述。
本发明实施例所提供的基于多层感知机的图分类方法基于所得到的图分类模型实现,所述图分类模型是利用作为教师模型的训练完成的图神经网络,经过知识蒸馏方法训练多层感知机得到;所述图分类模型在对社交网络中待分类的图数据进行推理过程时没有图的依赖性,能够保持有竞争力的准确度,同时能够减少图神经网络推理过程中的数据依赖问题,在保证较高精度的同时,能够大幅降低模型的计算复杂度,提高推理速度,同时具有更快的推理速度。
第三方面,相应于上述方法实施例,本发明实施例还提供了一种基于多层感知机的图分类装置,如图4所示,该装置包括:
图数据获取模块401,用于获取社交网络中待分类的目标图数据;
分类模块402,用于将所述目标图数据中的特征数据输入至预先训练完成的图分类模型中,得到对应的分类结果;其中,所述图分类模型根据第一方面提供的基于多层感知机的图分类模型构建方法得到。
关于该装置各个模块的具体处理过程请参见第二方面的相关内容,在此不做赘述。
第四方面,本发明实施例还提供了一种电子设备,如图5所示,包括处理器501、通信接口502、存储器503和通信总线504,其中,处理器501、通信接口502、存储器503通过通信总线504完成相互间的通信,
所述存储器,用于存放计算机程序;
所述处理器,用于执行所述存储器上所存放的程序时,实现本发明实施例第二方面所提供的基于多层感知机的图分类方法的步骤。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital SignalProcessing,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
本发明实施例提供的基于多层感知机的图分类方法可以应用于电子设备。具体的,该电子设备可以为:台式计算机、便携式计算机、智能移动终端、服务器等。在此不作限定,任何可以实现本发明的电子设备,均属于本发明的保护范围。
第五方面,相应于第二方面所提供的基于多层感知机的图分类方法,本发明实施例还提供了一种计算机可读存储介质,该计算机可读存储介质内存储有计算机程序,计算机程序被处理器执行时实现本发明实施例第二方面所提供的基于多层感知机的图分类方法的步骤。
对于装置/电子设备/存储介质实施例而言,由于其基本相似于对应的方法实施例,所以描述的比较简单,相关之处参见对应方法实施例的部分说明即可。
以上所述仅为本发明的较佳实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本发明的保护范围内。
Claims (10)
1.一种基于多层感知机的图分类模型构建方法,其特征在于,包括:
针对所需应用场景的分类任务,获取若干图数据构成训练集;其中,所述训练集中至少部分图数据带有类别标签;
利用所述训练集对选定的图神经网络进行训练,将训练完成的图神经网络作为教师网络,并保存所述图神经网络对应的分类结果;
将选定的多层感知机作为学生模型,基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,得到训练完成的学生模型作为图分类模型,用于对所需应用场景的其余图数据进行分类。
2.根据权利要求1所述的基于多层感知机的图分类模型构建方法,其特征在于,所述所需应用场景,包括:
社交场景。
3.根据权利要求1所述的基于多层感知机的图分类模型构建方法,其特征在于,所述选定的图神经网络,包括:
GraphSAGE和GIN。
4.根据权利要求1所述的基于多层感知机的图分类模型构建方法,其特征在于,所述选定的多层感知机,包括:
多层的全连接网络。
5.根据权利要求1所述的基于多层感知机的图分类模型构建方法,其特征在于,所述基于所述训练集、所述教师网络以及对应的分类结果,采用知识蒸馏的方法对所述学生模型进行训练,包括:
由所述训练集中图数据的特征数据构成输入数据集;
基于所述输入数据集、所述输入数据集中数据带有的类别标签、所述图神经网络对应的分类结果以及预设的损失函数,采用知识蒸馏的方法对所述学生模型进行训练直至所述学生模型收敛,得到训练完成的学生模型。
6.根据权利要求5所述的基于多层感知机的图分类模型构建方法,其特征在于,所述预设的损失函数为:
其中,v表示图数据;V表示所述训练集;VL表示所述训练集中有类别标签的图数据构成的集合;表示图数据v对应的学生模型的分类结果;yv表示图数据v的类别标签;zv表示图数据v对应的教师模型的分类结果;Llabel表示真实的类别标签与学生模型的分类结果之间的损失;Lteacher表示学生模型的分类结果和教师模型的分类结果之间的损失;λ表示权重参数。
7.一种基于多层感知机的图分类方法,其特征在于,包括:
获取社交场景中待分类的目标图数据;
将所述目标图数据中的特征数据输入至预先训练完成的图分类模型中,得到对应的分类结果;其中,所述图分类模型根据权利要求1-6任一项所述的基于多层感知机的图分类模型构建方法得到。
8.一种基于多层感知机的图分类装置,其特征在于,包括:
图数据获取模块,用于获取社交场景中待分类的目标图数据;
分类模块,用于将所述目标图数据中的特征数据输入至预先训练完成的图分类模型中,得到对应的分类结果;其中,所述图分类模型根据权利要求1-6任一项所述的基于多层感知机的图分类模型构建方法得到。
9.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,所述处理器、所述通信接口、所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行所述存储器上所存放的程序时,实现权利要求1-7任一所述的方法步骤。
10.一种计算机可读存储介质,其特征在于,
所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-7任一所述的方法步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311423387.7A CN117473315A (zh) | 2023-10-30 | 2023-10-30 | 一种基于多层感知机的图分类模型构建方法和图分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311423387.7A CN117473315A (zh) | 2023-10-30 | 2023-10-30 | 一种基于多层感知机的图分类模型构建方法和图分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117473315A true CN117473315A (zh) | 2024-01-30 |
Family
ID=89639094
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311423387.7A Pending CN117473315A (zh) | 2023-10-30 | 2023-10-30 | 一种基于多层感知机的图分类模型构建方法和图分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117473315A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117829320A (zh) * | 2024-03-05 | 2024-04-05 | 中国海洋大学 | 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法 |
-
2023
- 2023-10-30 CN CN202311423387.7A patent/CN117473315A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117829320A (zh) * | 2024-03-05 | 2024-04-05 | 中国海洋大学 | 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
You et al. | Cross-modality attention with semantic graph embedding for multi-label classification | |
Tingting et al. | Three‐stage network for age estimation | |
Kan et al. | Supervised deep feature embedding with handcrafted feature | |
CN111582409B (zh) | 图像标签分类网络的训练方法、图像标签分类方法及设备 | |
Xiaomei et al. | Microblog sentiment analysis with weak dependency connections | |
Jiang et al. | A unified multiple graph learning and convolutional network model for co-saliency estimation | |
CN109063719B (zh) | 一种联合结构相似性和类信息的图像分类方法 | |
WO2023065859A1 (zh) | 物品推荐方法、装置及存储介质 | |
Zhang et al. | A triple wing harmonium model for movie recommendation | |
Jiang et al. | Multiple graph convolutional networks for co-saliency detection | |
CN112016601A (zh) | 基于知识图谱增强小样本视觉分类的网络模型构建方法 | |
US20240037750A1 (en) | Generating improved panoptic segmented digital images based on panoptic segmentation neural networks that utilize exemplar unknown object classes | |
CN113065974A (zh) | 一种基于动态网络表示学习的链路预测方法 | |
CN117473315A (zh) | 一种基于多层感知机的图分类模型构建方法和图分类方法 | |
CN115293919B (zh) | 面向社交网络分布外泛化的图神经网络预测方法及系统 | |
Luo et al. | BCMM: A novel post-based augmentation representation for early rumour detection on social media | |
Chen et al. | Geoconv: Geodesic guided convolution for facial action unit recognition | |
Henríquez et al. | Twitter sentiment classification based on deep random vector functional link | |
Song et al. | Gratis: Deep learning graph representation with task-specific topology and multi-dimensional edge features | |
Shukla et al. | Role of hybrid optimization in improving performance of sentiment classification system | |
Gu et al. | Towards facial expression recognition in the wild via noise-tolerant network | |
Zhang et al. | Crowdnas: A crowd-guided neural architecture searching approach to disaster damage assessment | |
Chu et al. | A novel recommender system for E-commerce | |
Bin et al. | Combining multi-representation for multimedia event detection using co-training | |
Wang et al. | Simultaneously discovering and localizing common objects in wild images |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |