CN115907001B - 基于知识蒸馏的联邦图学习方法及自动驾驶方法 - Google Patents

基于知识蒸馏的联邦图学习方法及自动驾驶方法 Download PDF

Info

Publication number
CN115907001B
CN115907001B CN202211415148.2A CN202211415148A CN115907001B CN 115907001 B CN115907001 B CN 115907001B CN 202211415148 A CN202211415148 A CN 202211415148A CN 115907001 B CN115907001 B CN 115907001B
Authority
CN
China
Prior art keywords
model
graph
client
distillation
local
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211415148.2A
Other languages
English (en)
Other versions
CN115907001A (zh
Inventor
鲁鸣鸣
肖智祥
王诗雨
谢家豪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Central South University
Original Assignee
Central South University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Central South University filed Critical Central South University
Priority to CN202211415148.2A priority Critical patent/CN115907001B/zh
Publication of CN115907001A publication Critical patent/CN115907001A/zh
Application granted granted Critical
Publication of CN115907001B publication Critical patent/CN115907001B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Feedback Control In General (AREA)

Abstract

本发明公开了一种基于知识蒸馏的联邦图学习方法,包括在服务器上训练图神经网络得到教师模型;客户端获取教师模型;客户端采用教师模型训练自身的图神经网络模型得到本地模型;客户端训练本地模型并微调教师模型;服务器采用微调后的教师模型对服务器上的教师模型进行更新;重复以上步骤直至教师模型和本地模型更新完毕,完成基于知识蒸馏的联邦图学习。本发明还公开了一种包括所述基于知识蒸馏的联邦图学习方法的自动驾驶方法。本发明解决了现实场景下的数据孤岛问题和模型异质性问题,而且可靠性高,准确性好。

Description

基于知识蒸馏的联邦图学习方法及自动驾驶方法
技术领域
本发明属于人工智能技术领域,具体涉及一种基于知识蒸馏的联邦图学习方法及自动驾驶方法。
背景技术
随着经济技术的发展和人们生活水平的提高,人工智能技术已经广泛应用于人们的生产和生活当中,给人们的生产和生活带来了无尽的便利。图神经网络(Graph NeuralNetwork,GNN)具有强大的空间特征提取能力,在生物分子领域和社区预测等图分类领域有着非常优秀的表现,因此也受到了研究人员的广泛关注。
训练一个优秀的GNN模型需要大量的图样本数据;但是,在现实生活中,大量的图数据往往掌握在不同的部门或机构手中,而且由于隐私问题、商业竞争和法律法规等的限制,人们并无法集中不同部门或机构的图样本数据来训练一个集中式的GNN模型。为了解决这个数据孤岛问题,有研究人员从联邦学习的角度出发,提出了联邦图学习。联邦图学习通过参数或特征共享的方式支持多方协同训练公共模型,从而解决图神经网络上的数据孤岛的问题;这也是一种较好的分布式学习范式。但同普通联邦学习一样,联邦图学习同样存在着异质性的问题。
目前有研究人员针对联邦学习模型的异质性,提出了相应的解决方案,如FedMD模型。FedMD模型在客户端设置模型时仅考虑了修改网络层或者神经单元数量两个方面;但GNN模型根据算子的不同有多种类别,因此在实际应用中每个客户端可以采取不同的GNN模型,但FedMD模型并不支持这一点。另一方面,客户端之间计算能力也存在差异性,通常GNN的运算量随着单个图样本的数据规模而相应增加,在客户端存在硬件瓶颈限制的情形下,如何将强大的GNN模型学习到的知识,转移到无法应用GNN模型的客户端上,也是亟待解决的问题。
图神经网络在自动驾驶领域有着较好的应用;其能够在自动驾驶的过程中准确检测到车辆、行人的物体,并以此为自动驾驶的控制提供数据支撑。但是,正因为上述GNN模型学习的原因,使得GNN图神经网络在自动驾驶过程的应用也存在可靠性不高、准确性不好的问题。
发明内容
本发明的目的之一在于提供一种能够解决数据孤岛问题和模型异质性问题,而且可靠性高,准确性好的基于知识蒸馏的联邦图学习方法。
本发明的目的之二在于提供一种包括了所述基于知识蒸馏的联邦图学习方法的自动驾驶方法。
本发明提供的这种基于知识蒸馏的联邦图学习方法,包括如下步骤:
S1.在服务器上,通过公共数据集对图神经网络进行训练,得到教师模型;
S2.各个客户端从服务器上获取教师模型;
S3.基于知识蒸馏框架和本地私有数据,各个客户端采用步骤S2获取的教师模型对自身的图神经网络模型进行训练,得到各个客户端自身的本地模型;
S4.基于微调框架和本地私有数据,各个客户端对自身的本地模型进行训练,同时对获取的教师模型进行微调;
S5.基于联邦聚合算法,服务器采用各个客户端的微调后的教师模型对服务器的教师模型进行更新;
S6.重复步骤S2~S5直至设定的条件,结束教师模型和本地模型的更新,完成基于知识蒸馏的联邦图学习。
步骤S1所述的图神经网络,具体包括如下内容:
图神经网络为多层图同构神经网络;多层图同构神经网络能够将图神经网络中每个节点的邻居通过邻域聚合的方式汇聚局部空间的结构信息,然后再与自身节点的特征进行比例混合,最后采用能够映射到任意特征空间的全连接层来保证网络的单射特性,从而最大化图模型的表达能力;
采用如下算式作为多层图同构神经网络的计算式:
Figure BDA0003938623570000031
式中
Figure BDA0003938623570000032
为节点v在k层的隐藏特征;MLP(k)()为多层感知器;ε(k)为待学习的参数或者为一个设定的固定值,用于控制节点v特征在GNN迭代中的影响大小;u为邻居节点;N(v)为邻居节点的集合;/>
Figure BDA0003938623570000033
为邻居节点u在(k-1)层的隐藏特征;
同时,在多层图同构神经网络之间增加PairNorm操作,从而解决GNN随着网络层数增加而造成的过平滑问题;采用如下算式作为PairNorm操作的计算式:
Figure BDA0003938623570000034
Figure BDA0003938623570000035
式中
Figure BDA0003938623570000036
为中心化特征矩阵;/>
Figure BDA0003938623570000037
为节点表示矩阵;n为节点的数量;/>
Figure BDA0003938623570000038
为总的配对平方距离;s为用来控制缩放范围的超参数;/>
Figure BDA0003938623570000041
为2-范数的平方;算式/>
Figure BDA0003938623570000042
表示去中心化,用于对每一行的特征矩阵减去对应行特征向量的均值,对整体的数值进行中心化处理;算式/>
Figure BDA0003938623570000043
用于对特征矩阵进行重新缩放操作,让整体的节点的嵌入向量不再趋于一致,让整体节点之间的特征向量嵌入更加符合真实情况的节点分布,从而缓解随着网络层数增加造成的过平滑所带来的性能下降。
步骤S3所述的基于知识蒸馏框架和本地私有数据,各个客户端采用步骤S2获取的教师模型对自身的图神经网络模型进行训练,具体包括如下步骤:
若客户端采用的客户端模型包含GNN层,则采用H-KD蒸馏方法进行有图知识蒸馏;
若客户端采用的客户端模型为MLP模型,则进行无图知识蒸馏。
所述的采用H-KD蒸馏方法进行有图知识蒸馏,具体包括如下步骤:
在H-KD蒸馏方法中,本地模型需要学习的知识包括全连接层的输出和最终预测类别的软标签;采用如下算式计算软标签:
Figure BDA0003938623570000044
式中pi(zi,T)为第i类的类概率;zi为全连接层输出值z的第i维值;T为温度,用于控制软标签的重要性;k为模型预测类别的数量;
在H-KD蒸馏方法中,整体损失函数包括蒸馏损失和交叉熵损失;
采用如下算式计算蒸馏损失LD
Figure BDA0003938623570000051
式中pi(ti,T)为教师模型输出的软标签;ti为教师模型输出;si为学生模型的输出;pi(si,T)为学生模型输出的软标签;
采用如下算式计算交叉熵损失LS
Figure BDA0003938623570000052
式中yi为真实标签;pi(si,1)为温度T为1时的学生模型输出的软标签;
采用如下算式作为整体损失函数Ltotal
Ltotal=λLD+(1-λ)LS
式中λ为权重值。
所述的进行无图知识蒸馏,具体包括如下步骤:
本地模型为MLP模型;
训练时的损失函数包括蒸馏损失L'D和交叉熵损失L'S
计算蒸馏损失L'D
Figure BDA0003938623570000053
计算交叉熵损失L'S
Figure BDA0003938623570000054
采用如下算式作为整体损失函数Ltotal
Figure BDA0003938623570000055
式中λ'为权重值;V为所有节点;VL为带标签的节点。
步骤S5所述的联邦聚合算法,具体为FedAvg算法。
本发明还公开了一种包括所述基于知识蒸馏的联邦图学习方法的自动驾驶方法,具体包括如下步骤:
A.确定服务器上的初始教师模型和自动驾驶车辆上的初始本地模型;
B.采用所述的基于知识蒸馏的联邦图学习方法,进行初始教师模型和初始本地模型的学习和更新,得到最终的本地模型;
C.自动驾驶车辆采用步骤B得到的本地模型,在自动驾驶过程中进行周围环境中的物体进行识别;
D.根据步骤C的识别结果,对车辆进行控制,完成车辆的自动驾驶。
本发明提供的这种基于知识蒸馏的联邦图学习方法及自动驾驶方法,研究了联邦图学习场景下的模型异质性问题,考虑到不同客户端的算力不同,允许各个客户端有不同的本地模型,本发明通过知识蒸馏的框架,利用经过公共数据集和各客户端私有数据集训练的教师模型来指导各客户端的本地模型;在本发明方法下,各客户端甚至可以使用浅层的GNN模型或者MLP模型就可以达到较好的训练效果;因此本发明解决了现实场景下的数据孤岛问题和模型异质性问题,而且可靠性高,准确性好。
附图说明
图1为本发明学习方法的方法流程示意图。
图2为本发明自动驾驶方法的方法流程示意图。
具体实施方式
如图1所示为本发明学习方法的方法流程示意图:本发明提供的这种基于知识蒸馏的联邦图学习方法,包括如下步骤:
S1.在服务器上,通过公共数据集对图神经网络进行训练,得到教师模型;
具体实施时,图神经网络具体包括如下内容:
图神经网络为多层图同构神经网络(D-GNN);多层图同构神经网络能够将图神经网络中每个节点的邻居通过邻域聚合的方式汇聚局部空间的结构信息,然后再与自身节点的特征进行比例混合,最后采用能够映射到任意特征空间的全连接层来保证网络的单射特性,从而最大化图模型的表达能力;
采用如下算式作为多层图同构神经网络的计算式:
Figure BDA0003938623570000071
式中
Figure BDA0003938623570000072
为节点v在k层的隐藏特征;MLP(k)()为多层感知器;ε(k)为可学习的参数或者是一个固定值,用来控制节点v特征在GNN迭代中的影响大小;u为邻居节点;N(v)为邻居节点的集合;/>
Figure BDA0003938623570000073
为邻居节点u在(k-1)层的隐藏特征;
通常GNN模型在网络层加深时会存在过平滑的问题,随着网络层数的增加,节点特征趋向于收敛到相同或相似的向量;因此,在多层图同构神经网络之间增加PairNorm操作,从而解决GNN随着网络层数增加而造成的过平滑问题;采用如下算式作为PairNorm操作的计算式:
Figure BDA0003938623570000074
Figure BDA0003938623570000075
式中
Figure BDA0003938623570000076
为中心化特征矩阵;/>
Figure BDA0003938623570000077
为节点表示矩阵;n为节点的数量;/>
Figure BDA0003938623570000078
为总的配对平方距离;s为用来控制缩放范围的超参数;/>
Figure BDA0003938623570000079
为2-范数的平方;算式/>
Figure BDA00039386235700000710
表示去中心化,用于对每一行的特征矩阵减去对应行特征向量的均值,对整体的数值进行中心化处理;算式/>
Figure BDA00039386235700000711
用于对特征矩阵进行重新缩放操作,让整体的节点的嵌入向量不再趋于一致,让整体节点之间的特征向量嵌入更加符合真实情况的节点分布,从而缓解随着网络层数增加造成的过平滑所带来的性能下降;
S2.各个客户端从服务器上获取教师模型;
S3.基于知识蒸馏框架和本地私有数据,各个客户端采用步骤S2获取的教师模型对自身的图神经网络模型进行训练,得到各个客户端自身的本地模型;具体包括如下步骤:
若客户端采用的客户端模型包含GNN层,则采用H-KD蒸馏方法进行有图知识蒸馏;具体包括如下步骤:
在H-KD蒸馏方法中,本地模型需要学习的知识包括全连接层的输出和最终预测类别的软标签;采用如下算式计算软标签:
Figure BDA0003938623570000081
式中pi(zi,T)为第i类的类概率;zi为全连接层输出值z的第i维值;T为温度,用于控制软标签的重要性;k为模型预测类别的数量;
在H-KD蒸馏方法中,整体损失函数包括蒸馏损失和交叉熵损失;
采用如下算式计算蒸馏损失LD
Figure BDA0003938623570000082
式中式中pi(ti,T)为教师模型输出的软标签;ti为教师模型输出;si为学生模型的输出;pi(si,T)为学生模型输出的软标签;
采用如下算式计算交叉熵损失LS
Figure BDA0003938623570000091
式中yi为真实标签;pi(si,1)为温度T为1时的学生模型输出的软标签;
采用如下算式作为整体损失函数Ltotal
Ltotal=λLD+(1-λ)LS
式中λ为权重值,可以自行设定,也可以在训练过程动态调整;
若客户端采用的客户端模型为MLP模型,则进行无图知识蒸馏;本地模型为MLP模型;
训练时的损失函数包括蒸馏损失L'D和交叉熵损失L'S
计算蒸馏损失L'D
Figure BDA0003938623570000092
计算交叉熵损失L'S
Figure BDA0003938623570000093
采用如下算式作为整体损失函数Ltotal
Figure BDA0003938623570000094
式中λ'为权重值,可以自行设定,也可以在训练过程动态调整;V为所有节点;VL为带标签的节点;
S4.基于微调框架和本地私有数据,各个客户端对自身的本地模型进行训练,同时对获取的教师模型进行微调;步骤S1的预训练阶段的数据集并不能感知到全局数据,这样会导致D-GNN在客户端数据集的泛化上会存在偏差,针对该问题,本发明采取迁移学习领域常用的预训练——微调框架,在本地客户端蒸馏训练的同时,使用本地私有数据集对D-GNN模型进行微调;其作用是让D-GNN模型能够增加对本地数据集的泛化能力,经过微调后的模型能在知识蒸馏阶段更好指导本地模型的训练方向;
S5.基于联邦聚合算法(优选为FedAvg算法),服务器采用各个客户端的微调后的教师模型对服务器的教师模型进行更新;
迁移学习(本步骤)能够让D-GNN能理解每个客户端数据集的知识,为让服务器的模型能收集到全局知识更新,本文采取联邦学习的方式将每个客户端微调后的D-GNN模型的部分参数进行联邦汇聚;本发明采用FedAvg算法作为联邦聚合算法,通过FedAvg使用相对较少的通信轮次来提高模型性能;
S6.重复步骤S2~S5直至设定的条件,结束教师模型和本地模型的更新,完成基于知识蒸馏的联邦图学习。
以下结合实施例,对本发明的学习方法的有效性进行进一步验证:
本文实验采取了来自生物信息学和社交网络两个领域的五个真实世界数据集,每个数据集都有一组图,图的标签是二分类或多类别。数据集的汇总统计如表1所示:
表1数据集汇总统计数据示意表
数据集 图数量 特征 类别数 平均节点数 平均边数
NCI1 4110 37 2 29.87 32.30
PROTEINS 1113 3 2 39.06 72.82
IMDB-BIN 1000 135 2 19.77 96.53
IMDB-MUL 1500 88 3 12.00 65.94
REDDIT-BIN 2000 101 2 429.63 497.75
生物信息领域的数据集中的图节点都依据其生物特性而拥有相应的节点特征,而社交网络中节点不存在特征,本发明考虑的数据集都是无向图,因此针对无节点特征的社交数据,使用节点的度作为其特征。
在实验过程中,将数据集划分为公共数据集和私有数据集两个部分,划分比例为公共数据集占30%,私有数据集占70%。公共数据集用于模型的预训练,私有数据集随机分给不同客户端,按数据集大小分别设置每个客户端拥有的图数量为100-200。预训练和客户端本地训练的数据集分为训练集、测试集和验证集,其占数据样本总体比例为训练集占80%,验证集和测试集分别占10%。
在预训练阶段,先固定随机种子然后对公开数据集进行随机分割。实验中总共用200次训练迭代次数对预训练模型进行训练,并结合模型早停方法(Early Stopping)来增强模型泛化能力并减少预训练时间,训练中使用验证集作为其性能验证标准,并使用分类准确率作为评价指标,准确率计算公式为:正确分类样本数量/总样本数量;最后保存迭代训练中在验证集准确率最优模型相关参数。预训练模型即P-GIN网络设置为3层和5层结构,具体相关超参数如表2所示,其中优化器的选择为Adam。
表2超参数数据示意表
参数 数值 意义
随机种子 25 影响参数的初始化值
学习率 0.01 梯度更新的步长
激活函数 ReLu 防止出现梯度消失和梯度爆炸
参数初始化 Xavier 确保网络方差平稳
训练轮数 200 模型训练轮次
Dropout 0.5 缓解出现过拟合
使用最优模型在测试集的性能表现数据如表3所示:
表3最优模型在测试集的性能表现数据示意表
模型 P-3GIN P-4GIN P-5GIN P-1GIN
NCI1 80.6 79.2 82.6 73.3
PROTEINS 76.2 78.6 80.1 72.3
IMDB-BIN 78.8 70.1 73.3 63.1
IMDB-MUL 43.3 52.2 51.1 42.1
REDDIT-BIN 84.2 77.1 80.4 73.7
其中,GIN前面数字代表GIN网络层数量,表格数字是五次实验最佳模型在测试集的平均准确率。可以看出在NCI1和PROTEINS数据集上,随着GIN网络层数越深,模型准确率越高。而在社交领域数据上,最深网络P-5GIN表现不如略浅层的3-4层网络模型表现好,因为社交领域是将节点的度作为节点的特征信息,而分类任务上分别是电影类别(IMDB数据集)和社区讨论话题类别(REDDIT数据集),节点特征与类别关联性相对生物分子类的数据集略显薄弱,以PROTEINS为例,其节点特征具有生物学特性,与类别的关联性较大。而社交网络数据集的节点特征与类别关联性不大,所以在网络层加深,即使节点能捕获多阶邻接节点信息也没有性能上的提升,且在多类别分类中(IMDB-MUL)表现更差。基于此特点,本发明在后续实验中分别不同的实验数据设置不同的预训练模型架构。
本实验着重分析了本模型框架在不同图分类数据集上的分类性能,并对此进行评估分析。并选取了以下两种模型作为基准实验对比:
(1)无知识蒸馏本地训练(Local):Local即各个客户端仅使用本地数据集训练本地模型,不与服务器进行参数交换等通信过程,客户端模型的设置和数据集的划分如表4所示。
表4客户端模型的设置和数据集的划分示意表
数据集 教师模型 客户端数据量规模 客户端模型设置
NCI1 P-5GIN 250-300 3GCN+3SAGE+4MLP
PROTEINS P-4GIN 110-150 2GCN+2SAGE+2MLP
IMDB-BIN P-3GIN 90-120 2GCN+2SAGE+2MLP
IMDB-MUL P-4GIN 120-150 3GCN+2SAGE+3MLP
REDDIT-BIN P-3GIN 120-150 3GCN+3SAGE+4MLP
(2)无联邦本地知识蒸馏训练(KD-Local):KD-Local使用上节所训练完的模型指导客户端进行知识蒸馏,客户端模型的参数设置和数据划分方式与Local一致。
在本发明模型框架设置中,NCI1和PROTEINS数据集设置预训练模型为P-5GIN,其它四个社交网络数据集分别选取在上节表现最优模型参数为预训练模型。图分类任务性能的评估结果如表5所示:
表5图分类任务性能的评估结果示意表
数据集 Local KD-Local 本发明
NCI1 58.2±1.90 70.2±0.77 81.3±0.76
PROTEINS 68.1±2.41 71.9±1.19 77.1±0.6
IMDB-BIN 65.7±1.90 72.3±0.97 78.6±0.9
IMDB-MUL 35.4±2.25 41.1±2.41 50.4±0.96
REDDIT-BIN 70.1±1.18 78.6±1.32 82.9±0.77
实验给出了所评估结果在十次实验的平均准确率和标准偏差,其中单次实验准确率计算是取所有客户端本地模型在其验证集数据上的平均准确率。
如实验结果所示,可以得出如下结论:
(1)本发明所提出的基于知识蒸馏的联邦学习方法在生物分子和社交网络等五个公开图分类数据集上都取得了最优的效果。且客户端模型在NCI1、REDDIT-BIN和IMDB-BIN数据集上的分类效果与教师模型效果相当,其他数据集上的性能虽然略差于教师模型,但也达到了不错的分类效果。
(2)另一方面,从实验可以明显看到客户端只使用本地数据进行训练的效果欠佳,这是因为本申请采取的客户端模型都是单层的GNN和MLP,对图数据的嵌入学习能力欠佳,不能很好的表征图结构信息,而通过结合知识蒸馏能有效提高本地模型的分类能力,相比Local有了小幅度的提升,再通过联邦学习进行信息交互后能进一步提高整体模型的分类性能。
目前,自动驾驶领域中实现对周围物体的建模并对行人或其他车辆进行预测一件非常重要的工作。基于此,自动驾驶才能更好地对自车进行合理的决策和轨迹规划。在实际中,自动驾驶的车辆会先通过自身配置的雷达和摄像头等设备识别周身物体并对其建模,抽取实例特征,在这个过程中,为了提高预测的准确性和可解释性,很多采用了图神经网络技术来建模人、车和物等各种实例之间的关系,可以方便不同实例之间进行信息交换,比如当行人走上斑马线时,代表斑马线和行人的两个实例之间就会产生信息交换并引导模型关注他们之间的联系。
然而,通过图神经网络达到非常好的意图预测和轨迹预测效果,往往需要非常大型的深层次的模型,这对计算的硬件要求较高,同时这也需要大量的真实数据来对图神经网络模型进行训练。但是在实际生活中,大量的真实数据往往由不同的车辆产生,隶属于不同的自动驾驶公司;而且基于隐私保护的要求,这些数据往往无法统一集中起来训练一个大型的图神经网络模型。再者,不同的车辆往往都有其固定的行驶区域,可能针对其特定区域的真实数据所训练的图神经网络模型的效果更好;而且,在行驶过程中,往往需要快速地做出预测,这对模型的推理速度有较高的要求,大型的图神经网络往往难以达到快速推理的效果,且不同的自动驾驶车辆由于芯片等硬件条件的不同,能够支持的模型大小也不尽相同。
因此,这种情况就特别适用于本发明的技术方案。针对上述问题,本发明提供了一种包括所述基于知识蒸馏的联邦图学习方法的自动驾驶方法,具体包括如下步骤(方法流程示意图如图2所示):
A.确定服务器上的初始教师模型和自动驾驶车辆上的初始本地模型;
B.采用所述的基于知识蒸馏的联邦图学习方法,进行初始教师模型和初始本地模型的学习和更新,得到最终的本地模型;
C.自动驾驶车辆采用步骤B得到的本地模型,在自动驾驶过程中进行周围环境中的物体进行识别;
D.根据步骤C的识别结果,对车辆进行控制,完成车辆的自动驾驶。
步骤B所述的基于知识蒸馏的联邦图学习方法,包括如下步骤:
B1.在服务器上,通过公共数据集对图神经网络进行训练,得到教师模型;
B2.各个客户端从服务器上获取教师模型;
B3.基于知识蒸馏框架和本地私有数据,各个客户端采用步骤B2获取的教师模型对自身的图神经网络模型进行训练,得到各个客户端自身的本地模型;
B4.基于微调框架和本地私有数据,各个客户端对自身的本地模型进行训练,同时对获取的教师模型进行微调;
B5.基于联邦聚合算法,服务器采用各个客户端的微调后的教师模型对服务器的教师模型进行更新;
B6.重复步骤B2~B5直至设定的条件,结束教师模型和本地模型的更新,完成基于知识蒸馏的联邦图学习。
步骤B1所述的图神经网络,具体包括如下内容:
图神经网络为多层图同构神经网络;多层图同构神经网络能够将图神经网络中每个节点的邻居通过邻域聚合的方式汇聚局部空间的结构信息,然后再与自身节点的特征进行比例混合,最后采用能够映射到任意特征空间的全连接层来保证网络的单射特性,从而最大化图模型的表达能力;
采用如下算式作为多层图同构神经网络的计算式:
Figure BDA0003938623570000161
式中
Figure BDA0003938623570000162
为节点v在k层的隐藏特征;MLP(k)()为多层感知器;ε(k)为待学习的参数或者为一个设定的固定值,用于控制节点v特征在GNN迭代中的影响大小;u为邻居节点;N(v)为邻居节点的集合;/>
Figure BDA0003938623570000163
为邻居节点u在(k-1)层的隐藏特征;
同时,在多层图同构神经网络之间增加PairNorm操作,从而解决GNN随着网络层数增加而造成的过平滑问题;采用如下算式作为PairNorm操作的计算式:
Figure BDA0003938623570000164
Figure BDA0003938623570000165
式中
Figure BDA0003938623570000166
为中心化特征矩阵;/>
Figure BDA0003938623570000167
为节点表示矩阵;n为节点的数量;/>
Figure BDA0003938623570000168
为总的配对平方距离;s为用来控制缩放范围的超参数;/>
Figure BDA0003938623570000169
为2-范数的平方;算式/>
Figure BDA00039386235700001610
表示去中心化,用于对每一行的特征矩阵减去对应行特征向量的均值,对整体的数值进行中心化处理;算式/>
Figure BDA0003938623570000171
用于对特征矩阵进行重新缩放操作,让整体的节点的嵌入向量不再趋于一致,让整体节点之间的特征向量嵌入更加符合真实情况的节点分布,从而缓解随着网络层数增加造成的过平滑所带来的性能下降。
步骤B3所述的基于知识蒸馏框架和本地私有数据,各个客户端采用步骤B2获取的教师模型对自身的图神经网络模型进行训练,具体包括如下步骤:
若客户端采用的客户端模型包含GNN层,则采用H-KD蒸馏方法进行有图知识蒸馏;
若客户端采用的客户端模型为MLP模型,则进行无图知识蒸馏。
所述的采用H-KD蒸馏方法进行有图知识蒸馏,具体包括如下步骤:
在H-KD蒸馏方法中,本地模型需要学习的知识包括全连接层的输出和最终预测类别的软标签;采用如下算式计算软标签:
Figure BDA0003938623570000172
式中pi(zi,T)为第i类的类概率;zi为全连接层输出值z的第i维值;T为温度,用于控制软标签的重要性;k为模型预测类别的数量;
在H-KD蒸馏方法中,整体损失函数包括蒸馏损失和交叉熵损失;
采用如下算式计算蒸馏损失LD
Figure BDA0003938623570000173
式中pi(ti,T)为教师模型输出的软标签;ti为教师模型输出;si为学生模型的输出;pi(si,T)为学生模型输出的软标签;
采用如下算式计算交叉熵损失LS
Figure BDA0003938623570000181
式中yi为真实标签;pi(si,1)为温度T为1时的学生模型输出的软标签;
采用如下算式作为整体损失函数Ltotal
Ltotal=λLD+(1-λ)LS
式中λ为权重值。
所述的进行无图知识蒸馏,具体包括如下步骤:
本地模型为MLP模型;
训练时的损失函数包括蒸馏损失L'D和交叉熵损失L'S
计算蒸馏损失L'D
Figure BDA0003938623570000182
计算交叉熵损失L'S
Figure BDA0003938623570000183
采用如下算式作为整体损失函数Ltotal
Figure BDA0003938623570000184
式中λ'为权重值;V为所有节点;VL为带标签的节点。
步骤B5所述的联邦聚合算法,具体为FedAvg算法。
本发明提供的这种自动驾驶方法,尤其适用于现今实际生活中的自动驾驶过程;本发明提供的这种自动驾驶方法,最终能够在各车辆的客户端实现了效果好且推理速度快的本地模型(这种本地模型是小型模型,但是具有大型模型的性能),极大地提升了自动驾驶中对行人或其他车辆进行意图预测和轨迹预测的能力。

Claims (3)

1.一种基于知识蒸馏的联邦图学习方法的自动驾驶方法,具体包括如下步骤:
A.确定服务器上的初始教师模型和自动驾驶车辆上的初始本地模型;
B.采用基于知识蒸馏的联邦图学习方法,进行初始教师模型和初始本地模型的学习和更新,得到最终的本地模型;
C.自动驾驶车辆采用步骤B得到的本地模型,在自动驾驶过程中进行周围环境中的物体进行识别;
D.根据步骤C的识别结果,对车辆进行控制,完成车辆的自动驾驶;
其中,所述的基于知识蒸馏的联邦图学习方法,包括如下步骤:
S1.在服务器上,通过公共数据集对图神经网络进行训练,得到教师模型;
S2.各个客户端从服务器上获取教师模型;
S3.基于知识蒸馏框架和本地私有数据,各个客户端采用步骤S2获取的教师模型对自身的图神经网络模型进行训练,得到各个客户端自身的本地模型;具体包括如下步骤:
若客户端采用的客户端模型包含GNN层,则采用H-KD蒸馏方法进行有图知识蒸馏;具体包括如下步骤:
在H-KD蒸馏方法中,本地模型需要学习的知识包括全连接层的输出和最终预测类别的软标签;采用如下算式计算软标签:
Figure FDA0004228086050000011
式中pi(zi,T)为第i类的类概率;zi为全连接层输出值z的第i维值;T为温度,用于控制软标签的重要性;k为模型预测类别的数量;
在H-KD蒸馏方法中,整体损失函数包括蒸馏损失和交叉熵损失;
采用如下算式计算蒸馏损失LD
Figure FDA0004228086050000021
式中pi(ti,T)为教师模型输出的软标签;ti为教师模型输出;si为学生模型的输出;pi(si,T)为学生模型输出的软标签;
采用如下算式计算交叉熵损失LS
Figure FDA0004228086050000022
式中yi为真实标签;pi(si,1)为温度T为1时的学生模型输出的软标签;
采用如下算式作为整体损失函数Ltotal
Ltotal=λLD+(1-λ)LS
式中λ为权重值;
若客户端采用的客户端模型为MLP模型,则进行无图知识蒸馏;具体包括如下步骤:
本地模型为MLP模型;
训练时的损失函数包括蒸馏损失L'D和交叉熵损失L'S
计算蒸馏损失L'D
Figure FDA0004228086050000023
计算交叉熵损失L'S
Figure FDA0004228086050000024
采用如下算式作为整体损失函数L'total
Figure FDA0004228086050000025
式中λ'为第二权重值;V为所有节点;VL为带标签的节点;
S4.基于微调框架和本地私有数据,各个客户端对自身的本地模型进行训练,同时对获取的教师模型进行微调;
S5.基于联邦聚合算法,服务器采用各个客户端的微调后的教师模型对服务器的教师模型进行更新;
S6.重复步骤S2~S5直至设定的条件,结束教师模型和本地模型的更新,完成基于知识蒸馏的联邦图学习。
2.根据权利要求1所述的基于知识蒸馏的联邦图学习方法的自动驾驶方法,其特征在于步骤S1所述的图神经网络,具体包括如下内容:
图神经网络为多层图同构神经网络;多层图同构神经网络能够将图神经网络中每个节点的邻居通过邻域聚合的方式汇聚局部空间的结构信息,然后再与自身节点的特征进行比例混合,最后采用能够映射到任意特征空间的全连接层来保证网络的单射特性,从而最大化图模型的表达能力;
采用如下算式作为多层图同构神经网络的计算式:
Figure FDA0004228086050000031
式中
Figure FDA0004228086050000032
为节点v在k层的隐藏特征;MLP(k)()为多层感知器;ε(k)为待学习的参数或者为一个设定的固定值,用于控制节点v特征在GNN迭代中的影响大小;u为邻居节点;N(v)为邻居节点的集合;/>
Figure FDA0004228086050000033
为邻居节点u在(k-1)层的隐藏特征;
同时,在多层图同构神经网络之间增加PairNorm操作,从而解决GNN随着网络层数增加而造成的过平滑问题;采用如下算式作为PairNorm操作的计算式:
Figure FDA0004228086050000034
Figure FDA0004228086050000041
式中
Figure FDA0004228086050000042
为中心化特征矩阵;/>
Figure FDA0004228086050000043
为节点表示矩阵;n为节点的数量;/>
Figure FDA0004228086050000044
为总的配对平方距离;s为用来控制缩放范围的超参数;/>
Figure FDA0004228086050000045
为2-范数的平方;算式/>
Figure FDA0004228086050000046
表示去中心化,用于对每一行的特征矩阵减去对应行特征向量的均值,对整体的数值进行中心化处理;算式/>
Figure FDA0004228086050000047
用于对特征矩阵进行重新缩放操作,让整体的节点的嵌入向量不再趋于一致,让整体节点之间的特征向量嵌入更加符合真实情况的节点分布,从而缓解随着网络层数增加造成的过平滑所带来的性能下降。
3.根据权利要求2所述的基于知识蒸馏的联邦图学习方法的自动驾驶方法,其特征在于步骤S5所述的联邦聚合算法,具体为FedAvg算法。
CN202211415148.2A 2022-11-11 2022-11-11 基于知识蒸馏的联邦图学习方法及自动驾驶方法 Active CN115907001B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211415148.2A CN115907001B (zh) 2022-11-11 2022-11-11 基于知识蒸馏的联邦图学习方法及自动驾驶方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211415148.2A CN115907001B (zh) 2022-11-11 2022-11-11 基于知识蒸馏的联邦图学习方法及自动驾驶方法

Publications (2)

Publication Number Publication Date
CN115907001A CN115907001A (zh) 2023-04-04
CN115907001B true CN115907001B (zh) 2023-07-04

Family

ID=86475550

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211415148.2A Active CN115907001B (zh) 2022-11-11 2022-11-11 基于知识蒸馏的联邦图学习方法及自动驾驶方法

Country Status (1)

Country Link
CN (1) CN115907001B (zh)

Families Citing this family (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116415005B (zh) * 2023-06-12 2023-08-18 中南大学 一种面向学者学术网络构建的关系抽取方法
CN117097797B (zh) * 2023-10-19 2024-02-09 浪潮电子信息产业股份有限公司 云边端协同方法、装置、系统、电子设备及可读存储介质
CN117236421B (zh) * 2023-11-14 2024-03-12 湘江实验室 一种基于联邦知识蒸馏的大模型训练方法
CN118228777A (zh) * 2024-03-01 2024-06-21 南京航空航天大学 通过无数据蒸馏的联邦学习实现异构感知自动驾驶的方法
CN117829320B (zh) * 2024-03-05 2024-06-25 中国海洋大学 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113298229A (zh) * 2021-04-12 2021-08-24 云从科技集团股份有限公司 联邦学习模型训练方法、客户端、服务器及存储介质
CN114241282A (zh) * 2021-11-04 2022-03-25 河南工业大学 一种基于知识蒸馏的边缘设备场景识别方法及装置
CN114297927A (zh) * 2021-12-28 2022-04-08 中国科学院自动化研究所 基于数据驱动的工业装备数字孪生构建维护方法及系统
CN114429219A (zh) * 2021-12-09 2022-05-03 之江实验室 一种面向长尾异构数据的联邦学习方法
CN114943324A (zh) * 2022-05-26 2022-08-26 中国科学院深圳先进技术研究院 神经网络训练方法、人体运动识别方法及设备、存储介质
CN115115862A (zh) * 2022-05-20 2022-09-27 中国科学院计算技术研究所 基于异构图神经网络的高阶关系知识蒸馏方法及系统

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113298229A (zh) * 2021-04-12 2021-08-24 云从科技集团股份有限公司 联邦学习模型训练方法、客户端、服务器及存储介质
CN114241282A (zh) * 2021-11-04 2022-03-25 河南工业大学 一种基于知识蒸馏的边缘设备场景识别方法及装置
CN114429219A (zh) * 2021-12-09 2022-05-03 之江实验室 一种面向长尾异构数据的联邦学习方法
CN114297927A (zh) * 2021-12-28 2022-04-08 中国科学院自动化研究所 基于数据驱动的工业装备数字孪生构建维护方法及系统
CN115115862A (zh) * 2022-05-20 2022-09-27 中国科学院计算技术研究所 基于异构图神经网络的高阶关系知识蒸馏方法及系统
CN114943324A (zh) * 2022-05-26 2022-08-26 中国科学院深圳先进技术研究院 神经网络训练方法、人体运动识别方法及设备、存储介质

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
Federated Learning with Heterogeneous Architectures using Graph HyperNetworks;Or Litany 等;《arXiv:2201.08459v1 [cs.LG]》;全文 *
PAIRNORM: TACKLING OVERSMOOTHING IN GNNS;Lingxiao Zhao 等;《ICLR 2020》;全文 *
一文总览知识蒸馏概述;凉爽的安迪;《网页,https://mp.weixin.qq.com/s?__biz=MzI4MDYzNzg4Mw==&mid=2247493081&idx=6&sn=37df228117b8fcfe8d4f01f928fdf9fc&chksm=ebb7dd0ddcc0541b776bea4e5308928b919d2f6a2cb2f6b61863bf9d8b6c67362ac9bf21ea52&scene=27》;全文 *

Also Published As

Publication number Publication date
CN115907001A (zh) 2023-04-04

Similar Documents

Publication Publication Date Title
CN115907001B (zh) 基于知识蒸馏的联邦图学习方法及自动驾驶方法
CN102622515B (zh) 一种天气预测方法
CN106789149B (zh) 采用改进型自组织特征神经网络聚类算法的入侵检测方法
CN107944410A (zh) 一种基于卷积神经网络的跨领域面部特征解析方法
CN107203752A (zh) 一种联合深度学习和特征二范数约束的人脸识别方法
CN106980831A (zh) 基于自编码器的自亲缘关系识别方法
Yang et al. Federated continual learning via knowledge fusion: A survey
CN106647272A (zh) 基于k均值改进卷积神经网络的机器人路径规划方法
Zhang et al. Surface and high-altitude combined rainfall forecasting using convolutional neural network
Leibfried et al. A reward-maximizing spiking neuron as a bounded rational decision maker
Shen et al. An attention-based digraph convolution network enabled framework for congestion recognition in three-dimensional road networks
Chen et al. Feature extraction method of 3D art creation based on deep learning
CN108073978A (zh) 一种人工智能超深度学习模型的构成方法
CN116187469A (zh) 一种基于联邦蒸馏学习框架的客户端成员推理攻击方法
Li et al. Adaptive dropout method based on biological principles
CN114969078A (zh) 一种联邦学习的专家研究兴趣实时在线预测更新方法
Yang et al. Retinal vessel segmentation based on an improved deep forest
US20230118025A1 (en) Federated mixture models
Kim et al. K-FL: Kalman Filter-Based Clustering Federated Learning Method
CN109242089A (zh) 递进监督深度学习神经网络训练方法、系统、介质和设备
Bhaumik et al. STLGRU: Spatio-temporal lightweight graph GRU for traffic flow prediction
CN105809200A (zh) 一种生物启发式自主抽取图像语义信息的方法及装置
Xue et al. Tree-like branching network for multi-class classification
CN105389599A (zh) 基于神经模糊网络的特征选择方法
Wang et al. Visual information computing and processing model based on artificial neural network

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant