CN114418085A - 一种基于神经网络模型剪枝的个性化协作学习方法和装置 - Google Patents

一种基于神经网络模型剪枝的个性化协作学习方法和装置 Download PDF

Info

Publication number
CN114418085A
CN114418085A CN202111453868.3A CN202111453868A CN114418085A CN 114418085 A CN114418085 A CN 114418085A CN 202111453868 A CN202111453868 A CN 202111453868A CN 114418085 A CN114418085 A CN 114418085A
Authority
CN
China
Prior art keywords
model
models
training
local
importance
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111453868.3A
Other languages
English (en)
Inventor
徐恪
刘泱
赵乙
朱敏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Original Assignee
Tsinghua University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University filed Critical Tsinghua University
Priority to CN202111453868.3A priority Critical patent/CN114418085A/zh
Publication of CN114418085A publication Critical patent/CN114418085A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Feedback Control In General (AREA)

Abstract

本发明公开了一种基于神经网络模型剪枝的个性化协作学习方法和装置,其中,该方法包括:利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。本发明能够保持模型对本地数据的适应能力,同时能够增强模型的泛化能力,在参数聚合时能避免数据分布差异过大的模型相互干扰。

Description

一种基于神经网络模型剪枝的个性化协作学习方法和装置
技术领域
本发明涉及深度学习(协作学习)领域,尤其涉及一种基于神经网络模型剪枝的个性化协作学习方法和装置。
背景技术
近些年来,随着大数据时代的到来以及计算设备算力的大幅度提升,深度学习技术已经得到充分的发展与应用,比如在计算视觉、自然语言处理、自动驾驶和网络空间安全等领域都有非常丰富的落地场景。以往深度学习训练模型的方式需要将大量的数据汇集到数据中心,进行集中式的训练。这种集中式的训练方式带来两个问题。首先,从数以千万记的边缘设备上采集数据将会带来庞大的上行带宽消耗。以中国和美国为代表的许多国家,互联网络的上行带宽远小于下行带宽,大规模的上传数据很容易造成网络拥塞,降低系统的运行效率;更令人担忧的是,将边缘设备上的数据传输到中心服务器可能泄露边缘设备用户的数据隐私,带来巨大的安全隐患。
协作学习作为一种新型深度学习范式较好地解决了以上两个问题。协作学习是以一个中心服务器(群)组织若干边缘设备进行模型训练。中心服务器将全局模型下发至各个边缘设备。各边缘设备使用各自的数据以及梯度下降算法在本地对模型参数进行更新,待完成参数更新后,再将模型的更新结果上传至中心服务器。中心服务器对从各个边缘设备收到的模型更新结果进行参数聚合,从而实现全局模型的训练。在协作学习的过程中,用户数据始终保留在设备本地,边缘设备与中心服务器之间只传输模型参数的更新结果,极大程度地避免了数据隐私的泄露。但是协作学习依然面临诸多棘手问题:其一,通信效率的问题依然没有得到彻底解决。在协作学习过程中,需要在边缘设备和中心服务器之间相互传输模型更新结果,如果模型参数量很大,将会消耗大量的网络传输资源。其二,对于集中式的模型训练方式,默认的前提假设为数据分布是独立同分布的,并且这一假设在该训练方式下也往往能够成立。而在真实的分布式互联网架构的场景下,各边缘设备上的数据差异很大,数据分布不再满足这一假设,这对于模型的训练效果将产生很大的负面影响。
发明内容
本发明旨在至少在一定程度上解决相关技术中的技术问题之一。
为此,本发明的目的在于设计一种通信高效的个性化协作学习方法,使之在非独立同分布的数据环境下依然能够取得很好的模型训练效果,在各边缘设备的数据分布差异很大的情况下,依然能够训练得到适应于各自数据的个性化模型,并进一步提升模型的预测准确度;通过生成掩码矩阵对网络模型进行剪枝,为进一步压缩模型的传输规模奠定了基础;相比于主流协作学习方法仅多引入了目标裁剪率这一超参数,使得本方法的调优十分容易,也使得本方法能够快速、可靠地部署到各种现实的复杂环境当中。
本发明的另一个目的在于提出一种基于神经网络模型剪枝的个性化协作学习装置。
为达上述目的,本发明一方面提出了一种基于神经网络模型剪枝的个性化协作学习方法,包括以下步骤:
S1,利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;
S2,各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;以及,
S3,各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。
本发明实施例的基于神经网络模型剪枝的个性化协作学习方法,利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。本发明能够保持模型对本地数据的适应能力,也能够增强模型的泛化能力,并为压缩参数矩阵奠定了基础,在参数聚合时还能够避免数据分布差异过大的模型相互干扰,从而实现高效通信的个性化模型训练。
另外,根据本发明上述实施例的基于神经网络模型剪枝的个性化协作学习方法还可以具有以下附加的技术特征:
进一步地,步骤S2,包括:
S21,利用边缘设备Ck基于全局模型
Figure BDA0003387171030000021
使用Ck本地数据进行训练,至收敛时停止,得到模型ω′k
S22,基于模型ω′k对参数的重要性进行评估,得到参数wij的重要性权值Ωij
S23,基于重要性权值Ωij得到重要性权值矩阵Ωk,根据重要性权值矩阵Ωk生成参数掩码矩阵mk
进一步地,步骤S22中的重要性权值Ωij,包括:
重要性权值Ωij,是根据如下等式计算得到:
Figure BDA0003387171030000031
其中,Ndp是评估模型参数重要性的过程中所使用的数据样例的数量,gij(xd)是参数wij对于数据样例xd的更新梯度;
更新梯度gij(xd),是根据如下等式计算得到:
Figure BDA0003387171030000032
其中,F(xd;w)为模型w在数据样例xd上的输出,
Figure BDA0003387171030000033
为L2范数。
进一步地,步骤S23,包括:
根据目标裁剪率p,对于每层神经网络,对重要性权值矩阵的元素按照绝对值大小进行排序,裁剪绝对值最小的p比例的元素对应的权重,则掩码矩阵mk对应位置的元素值为0,未被裁剪的权重的对应位置的元素值为1,以得到参数掩码矩阵mk
进一步地,步骤S3,包括:
S31,对于N个边缘设备以及随机采样率K,随机采样N*K个边缘设备参与当前轮协作训练,则参与第r轮协作训练的边缘设备数量为s=max(N*K,1),边缘设备构成集合Sr={C1,...,Cs};
S32,中心服务器将全局模型
Figure BDA0003387171030000034
下发至S31中选出的边缘设备Sr,各边缘设备 Ck∈Sr接收到全局模型
Figure BDA0003387171030000035
后,对全局模型
Figure BDA0003387171030000036
使用参数掩码矩阵mk进行裁剪,为
Figure BDA0003387171030000037
S33,利用边缘设备Ck对模型
Figure BDA0003387171030000038
的训练,将训练完成后的模型
Figure BDA0003387171030000039
上传至中心服务器;
S34:中心服务器收集到所述Sr集合中所有边缘设备的模型,或超过预定的等待时间,将对已收集到的模型进行参数聚合。
进一步地,步骤S33,包括:
基于边缘设备Ck对模型
Figure BDA00033871710300000310
进行训练,为如下等式:
Figure BDA00033871710300000311
其中,η表示学习率,L(·)为进行优化的目标函数,⊙表示按元素对应位置进行的乘法。
进一步地,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,包括:
对边缘设备上的模型进行参数迭代的同时,对本地模型ω′k进行知识蒸馏,则目标函数为如下等式:
L=Lhard+αLsoft=CE(y,p)+αCE(q,p)
其中,CE(·)为交叉熵,y为数据样例真实标签的独热编码,p为学生网络的输出结果,q为对教师网络的输出软化后的结果。
进一步地,基于所述对教师网络的输出软化后的结果,以作为学生网络的软标签,为如下等式:
Figure BDA0003387171030000041
其中,zi是神经网络在softmax层之前的输出,T是蒸馏温度。
为达到上述目的,本发明另一方面提出了一种基于神经网络模型剪枝的个性化协作学习装置,包括:
第一评估模块,用于利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;第二评估模块,用于各边缘设备接收到所述全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;以及,
训练模块,用于各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。
本发明实施例的基于神经网络模型剪枝的个性化协作学习装置,利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;将各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。本发明能够保持模型对本地数据的适应能力,也能够增强模型的泛化能力,并为压缩参数矩阵奠定了基础,在参数聚合时还能够避免数据分布差异过大的模型相互干扰,从而实现高效通信的个性化模型训练。
本发明的有益效果:
首先,各边缘设备的参数掩码矩阵是由对其设备上的数据训练得到的,能够反映边缘设备的数据特征,并且“找到”适合于该数据特征的模型子结构。相似的模型子结构的协作训练能够增强彼此的泛化能力,提升预测准确性;不相似的模型子结构在参数聚合时也不会互相干扰,避免各自预测能力的退化;
其次,本发明采用知识蒸馏技术,在通过协作训练增强被剪枝模型泛化能力的同时,也使其不会随着训练的进行而“遗忘”本地数据的特征;就通信效率而言,各边缘设备的参数掩码矩阵能够非常稀疏,从而使得模型的参数矩阵十分稀疏,为模型参数矩阵的压缩奠定了良好基础,从而有效实现了在协作训练过程中的高效通信。
最后,本发明中的基于模型参数对数据样例敏感度的参数重要性评价方法,使得所获得的参数掩码矩阵更具可信性,更有助于刻画模型的最优子结构。
总之,本发明既实现个性化的模型训练,也保持了高效的通信,节省了网络带宽。并且,本发明还具有超参数数量少,易于调优的特点,能够方便快捷地在各种实际复杂场景中部署、使用。
本发明附加的方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
本发明上述的和/或附加的方面和优点从下面结合附图对实施例的描述中将变得明显和容易理解,其中:
图1为根据本发明实施例的基于神经网络模型剪枝的个性化协作学习方法的流程图;
图2为根据本发明实施例的基于神经网络模型剪枝的个性化协作学习方法的示意图;
图3为根据本发明实施例的边缘设备端的协作学习方法工作的示意图;
图4为根据本发明实施例的中心服务器进行参数聚合工作的示意图;
图5为根据本发明实施例的一种基于神经网络模型剪枝的个性化协作学习装置的结构示意图;
图6为根据本发明实施例的另一种基于神经网络模型剪枝的个性化协作学习装置的结构示意图。
具体实施方式
需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本发明。
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
下面参照附图描述根据本发明实施例提出的基于神经网络模型剪枝的个性化协作学习方法及装置,首先将参照附图描述根据本发明实施例提出的基于神经网络模型剪枝的个性化协作学习方法。
图2为基于模型剪枝与知识蒸馏的个性化协作学习方法的示意图,如图2所示,该方法分为模型参数评估阶段和正式训练阶段。在模型参数评估阶段,中心服务器向参与协作学习的每个边缘设备下发全局模型。各边缘设备分别对该模型进行训练,直至模型达到基本收敛,从而得到本地模型。通过累计本地模型中参数对本地数据样例的梯度之和,评估该模型各个参数对数据样例的敏感度,并以此作为模型参数重要性的评价指标。在此基础上给定目标裁剪率p,对作为本地模型的神经网络逐层生成用于模型剪枝的掩码矩阵。在正式训练阶段,各边缘设备通过该掩码矩阵,对全局模型的另一拷贝,即参与协作训练的模型进行剪枝,各边缘设备剪枝后的模型参与协作学习训练。在训练过程中,剪枝后的模型会作为学生网络对本地模型进行知识蒸馏,以保持对本地数据的适应能力。当各边缘设备完成本地的训练过程,将剪枝后的模型上传至中心服务器,中心服务器只对没有被剪去的参数进行聚合,以此增强数据分布相似的模型的相互学习,避免了数据分布相差过大的模型相互干扰。本发明的创新性在于,首次提出了网络剪枝-知识蒸馏-子网络参数聚合的协作学习方法,在边缘设备的数据分布极不均匀的情况下,保持高效通信的同时使各边缘设备能够训练个性化的模型,增强了边缘设备对本地数据的预测能力。
图1是本发明一个实施例的基于神经网络模型剪枝的个性化协作学习方法的流程图。
如图1所示,该基于神经网络模型剪枝的个性化协作学习方法包括以下步骤:
S1,利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备。
具体地,本发明的中心服务器初始化全局模型
Figure BDA0003387171030000061
并将
Figure BDA0003387171030000062
下发至各边缘设备。在本发明一个实施例中所提出的协作学习方法中,各个边缘设备将使用与全局模型的结构相同的网络,都是由中心服务器下发所得到。
S2,各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵。
具体地,各边缘设备Ck(k=1,...,N)接收到中心服务器下发的全局模型
Figure BDA0003387171030000063
后,各自独立地开始对模型参数的重要性进行评估,并且生成参数掩码矩阵:
(2.1)边缘设备Ck基于全局模型
Figure BDA0003387171030000064
使用Ck本地数据进行训练,至收敛时停止,得到模型ω′k
(2.2)边缘设备Ck基于模型ω′k对参数的重要性进行评估,得到参数wij的重要性权值Ωij。Ωij依照如下等式计算得到:
Figure BDA0003387171030000071
其中,Ndp是评估模型参数重要性的过程中所使用的数据样例的数量,gij(xd)是参数wij。对于数据样例xd的更新梯度,依照如下等式计算得到:
Figure BDA0003387171030000072
其中,F(xd;w)为模型w在数据样例xd上的输出,
Figure BDA0003387171030000073
为L2范数。
(2.3)边缘设备Ck基于重要性权值矩阵Ωk生成参数掩码矩阵mk。按照给定的目标裁剪率p,对于每层神经网络,对重要性权值矩阵的元素按照绝对值大小进行排序,绝对值最小的p比例的元素对应的权重会被裁剪,即掩码矩阵mk对应位置的元素值为0,未被裁剪的权重的对应位置的元素值为1,从而得到参数掩码矩阵mk
S3,各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。
具体地,步骤S3,包括:
(3.1)对于N个边缘设备以及随机采样率K,随机采样N*K个边缘设备参与当前轮协作训练,则参与第r轮协作训练的边缘设备数量为s=max(N*K,1),边缘设备构成集合Sr={C1,...,Cs}。
(3.2)中心服务器将全局模型
Figure BDA0003387171030000074
下发至步骤(3.1)中选出的边缘设备Sr,各边缘设备Ck∈Sr接收到全局模型
Figure BDA0003387171030000075
后,对全局模型
Figure BDA0003387171030000076
使用参数掩码矩阵mk进行裁剪,为
Figure BDA0003387171030000077
随后,边缘设备Ck对模型
Figure BDA0003387171030000078
使用小批量梯度下降法进行训练,为如下等式:
Figure BDA0003387171030000079
其中,η表示学习率,L(·)为进行优化的目标函数,⊙表示按元素对应位置进行的乘法。
边缘设备上的模型进行参数迭代的同时,对本地模型ω′k进行知识蒸馏,这一过程体现于目标函数,如下面等式所示:
L=Lhard+αLsoft=CE(y,p)+αCE(q,p)
其中,CE(·)为交叉熵,y为数据样例真实标签的独热编码,p为学生网络的输出结果,q为对教师网络的输出软化后的结果,以此作为学生网络的“软标签”,具体如下面的等式所示:
Figure BDA0003387171030000081
其中,zi是神经网络在softmax层之前的输出,T是蒸馏温度。
在边缘设备端的协作学习方法工作流程示意图,如图3所示。
(3.3)基于边缘设备Ck对模型
Figure BDA0003387171030000082
的训练,将训练完成后的模型
Figure BDA0003387171030000083
上传至中心服务器。
(3.4)中心服务器收集到Sr集合中所有边缘设备的模型,或超过预定的等待时间,将对已收集到的模型进行参数聚合。对于模型中同一位置的权重参数,服务器只对未被裁剪的参数进行算数平均,即是忽略该位置的权重参数为0的模型,如图4所示。
通过上述步骤,利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。本发明能够保持模型对本地数据的适应能力,也能够增强模型的泛化能力,并为压缩参数矩阵奠定了基础,在参数聚合时还能避免数据分布差异过大的模型相互干扰,从而实现高效通信的个性化模型训练。
综上,本发明能够广泛应用于个人移动设备、物联网设备以及边缘计算节点,在边缘端进行模型训练与参数更新,该协作学习方法相比于之前的主流方法,只多引入了目标裁剪率p这一超参数,使之易于在各种复杂环境下部署,具有很高的实用价值和广泛的应用前景。
综上所述,为了解决协作学习的通信效率和数据非独立同分布的问题,本发明提出了一种基于模型剪枝与知识蒸馏的个性化协作学习方法。在边缘设备本地,本发明首先训练一个基本收敛的模型,以此得到模型中参数对本地数据的敏感程度,再根据目标裁剪率得到掩码矩阵。在协作训练过程中,使用掩码矩阵对参与协作训练的模型进行非结构化剪枝,并且同时作为学生网络,对之前的本地模型进行知识蒸馏。这样的训练方式既能够保持模型对本地数据的适应能力,也能够增强模型的泛化能力。此外这种训练方式也为压缩参数矩阵奠定了基础,在参数聚合时还能避免数据分布差异过大的模型相互干扰,从而实现高效通信的个性化模型训练。
为了实现上述实施例,如图5所示,本实施例中还提供了一种基于神经网络模型剪枝的个性化协作学习装置10,该装置10包括:第一评估模块100、第二评估模块200和训练模块300。
第一评估模块100,用于利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;
第二评估模块200,用于各边缘设备接收到所述全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;
训练模块300,用于各边缘设备通过所述参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。
如图6所示,本发明实施例中,第二评估模块200,包括:
第三评估子模块201,用于利用边缘设备Ck基于全局模型
Figure BDA0003387171030000091
使用Ck本地数据进行训练,至收敛时停止,得到模型ω′k
第四评估子模块202,用于基于模型ω′k对参数的重要性进行评估,得到参数wij的重要性权值Ωij
第五评估子模块203,用于基于重要性权值Ωij得到重要性权值矩阵Ωk,根据重要性权值矩阵Ωk生成参数掩码矩阵mk
根据本发明实施例的基于神经网络模型剪枝的个性化协作学习装置,利用中心服务器初始化全局模型,并将全局模型下发至各边缘设备;各边缘设备接收到全局模型后,分别对全局模型进行训练以得到本地模型,基于本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;各边缘设备通过参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对本地模型进行知识蒸馏,并将剪枝后的模型上传至中心服务器,以对未被剪去的参数进行聚合。本发明能够保持模型对本地数据的适应能力,也能够增强模型的泛化能力,并为压缩参数矩阵奠定了基础,在参数聚合时还能够避免数据分布差异过大的模型相互干扰,从而实现高效通信的个性化模型训练。
需要说明的是,前述对基于神经网络模型剪枝的个性化协作学习方法实施例的解释说明也适用于该实施例的基于神经网络模型剪枝的个性化协作学习装置,此处不再赘述。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。在本发明的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。

Claims (10)

1.一种基于神经网络模型剪枝的个性化协作学习方法,其特征在于,包括以下步骤:
S1,利用中心服务器初始化全局模型,并将所述全局模型下发至各边缘设备;
S2,所述各边缘设备接收到所述全局模型后,分别对所述全局模型进行训练以得到本地模型,基于所述本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;以及,
S3,所述各边缘设备通过所述参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对所述本地模型进行知识蒸馏,并将所述剪枝后的模型上传至所述中心服务器,以对未被剪去的参数进行聚合。
2.根据权利要求1所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,所述S2,包括:
S21,利用边缘设备Ck基于全局模型
Figure FDA0003387171020000011
使用Ck本地数据进行训练,至收敛时停止,得到模型ω′k
S22,基于所述模型ω′k对参数的重要性进行评估,得到参数wij的重要性权值Ωij
S23,基于所述重要性权值Ωij得到重要性权值矩阵Ωk,根据所述重要性权值矩阵Ωk生成参数掩码矩阵mk
3.根据权利要求2所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,所述重要性权值Ωij,是根据如下等式计算得到:
Figure FDA0003387171020000012
其中,Ndp是评估模型参数重要性的过程中所使用的数据样例的数量,gij(xd)是参数wij对于数据样例xd的更新梯度;
所述更新梯度gij(xd),是根据如下等式计算得到:
Figure FDA0003387171020000013
其中,F(xd;w)为模型w在数据样例xd上的输出,
Figure FDA0003387171020000014
为L2范数。
4.根据权利要求2所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,所述S23,包括:
根据目标裁剪率p,对于每层神经网络,对重要性权值矩阵的元素按照绝对值大小进行排序,裁剪绝对值最小的p比例的元素对应的权重,则掩码矩阵mk对应位置的元素值为0,未被裁剪的权重的对应位置的元素值为1,以得到所述参数掩码矩阵mk
5.根据权利要求1所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,所述S3,包括:
S31,对于N个边缘设备以及随机采样率K,随机采样N*K个边缘设备参与当前轮协作训练,则参与第r轮协作训练的边缘设备数量为s=max(N*K,1),边缘设备构成集合Sr={C1,...,Cs};
S32,中心服务器将全局模型
Figure FDA0003387171020000021
下发至所述S31中选出的边缘设备Sr,各边缘设备Ck∈Sr接收到全局模型
Figure FDA0003387171020000022
后,对所述全局模型
Figure FDA0003387171020000023
使用参数掩码矩阵mk进行裁剪,为
Figure FDA0003387171020000024
S33,利用所述边缘设备Ck对所述模型
Figure FDA0003387171020000025
进行训练,并将训练完成后的模型
Figure FDA0003387171020000026
上传至所述中心服务器;
S34:所述中心服务器收集到所述Sr集合中所有边缘设备的模型,或超过预定的等待时间,将对已收集到的模型进行参数聚合。
6.根据权利要求5所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,所述S33,包括:
基于所述边缘设备Ck对所述模型
Figure FDA0003387171020000027
进行训练,为如下等式:
Figure FDA0003387171020000028
其中,η表示学习率,L(·)为进行优化的目标函数,⊙表示按元素对应位置进行的乘法。
7.根据权利要求6所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,对所述边缘设备上的模型进行参数迭代的同时,对所述本地模型ω′k进行知识蒸馏,则目标函数为如下等式:
L=Lhard+αLsoft=CE(y,p)+αCE(q,p)
其中,CE(·)为交叉熵,y为数据样例真实标签的独热编码,p为学生网络的输出结果,q为对教师网络的输出软化后的结果。
8.根据权利要求7所述的基于神经网络模型剪枝的个性化协作学习方法,其特征在于,基于所述对教师网络的输出软化后的结果,以作为学生网络的软标签,为如下等式:
Figure FDA0003387171020000029
其中,zi是神经网络在softmax层之前的输出,T是蒸馏温度。
9.一种基于神经网络模型剪枝的个性化协作学习装置,其特征在于,包括:
第一评估模块,用于利用中心服务器初始化全局模型,并将所述全局模型下发至各边缘设备;第二评估模块,用于所述各边缘设备接收到所述全局模型后,分别对所述全局模型进行训练以得到本地模型,基于所述本地模型对模型参数的重要性进行评估,并生成参数掩码矩阵;以及,
训练模块,用于所述各边缘设备通过所述参数掩码矩阵,对参与协作训练的模型进行剪枝,将剪枝后的模型作为学生网络对所述本地模型进行知识蒸馏,并将所述剪枝后的模型上传至所述中心服务器,以对未被剪去的参数进行聚合。
10.根据权利要求9所述的基于神经网络模型剪枝的个性化协作学习装置,其特征在于,所述第二评估模块,还包括:
第三评估子模块,用于利用边缘设备Ck基于全局模型
Figure FDA0003387171020000031
使用Ck本地数据进行训练,至收敛时停止,得到模型ω′k
第四评估子模块,用于基于所述模型ω′k对参数的重要性进行评估,得到参数wij的重要性权值Ωij
第五评估子模块,用于基于所述重要性权值Ωij得到重要性权值矩阵Ωk,根据所述重要性权值矩阵Ωk生成参数掩码矩阵mk
CN202111453868.3A 2021-12-01 2021-12-01 一种基于神经网络模型剪枝的个性化协作学习方法和装置 Pending CN114418085A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111453868.3A CN114418085A (zh) 2021-12-01 2021-12-01 一种基于神经网络模型剪枝的个性化协作学习方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111453868.3A CN114418085A (zh) 2021-12-01 2021-12-01 一种基于神经网络模型剪枝的个性化协作学习方法和装置

Publications (1)

Publication Number Publication Date
CN114418085A true CN114418085A (zh) 2022-04-29

Family

ID=81264667

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111453868.3A Pending CN114418085A (zh) 2021-12-01 2021-12-01 一种基于神经网络模型剪枝的个性化协作学习方法和装置

Country Status (1)

Country Link
CN (1) CN114418085A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114881229A (zh) * 2022-07-07 2022-08-09 清华大学 一种基于参数渐次冻结的个性化协作学习方法和装置
CN114936078A (zh) * 2022-05-20 2022-08-23 天津大学 一种微网群边缘调度与智能体轻量化裁剪方法
CN115186937A (zh) * 2022-09-09 2022-10-14 闪捷信息科技有限公司 基于多方数据协同的预测模型训练、数据预测方法和装置
CN117194992A (zh) * 2023-11-01 2023-12-08 支付宝(杭州)信息技术有限公司 一种模型训练、任务执行方法、装置、存储介质及设备

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114936078A (zh) * 2022-05-20 2022-08-23 天津大学 一种微网群边缘调度与智能体轻量化裁剪方法
CN114881229A (zh) * 2022-07-07 2022-08-09 清华大学 一种基于参数渐次冻结的个性化协作学习方法和装置
CN114881229B (zh) * 2022-07-07 2022-09-20 清华大学 一种基于参数渐次冻结的个性化协作学习方法和装置
CN115186937A (zh) * 2022-09-09 2022-10-14 闪捷信息科技有限公司 基于多方数据协同的预测模型训练、数据预测方法和装置
CN117194992A (zh) * 2023-11-01 2023-12-08 支付宝(杭州)信息技术有限公司 一种模型训练、任务执行方法、装置、存储介质及设备
CN117194992B (zh) * 2023-11-01 2024-04-19 支付宝(杭州)信息技术有限公司 一种模型训练、任务执行方法、装置、存储介质及设备

Similar Documents

Publication Publication Date Title
CN114418085A (zh) 一种基于神经网络模型剪枝的个性化协作学习方法和装置
CN112367109B (zh) 空地网络中由数字孪生驱动的联邦学习的激励方法
Seo et al. Physics-aware difference graph networks for sparsely-observed dynamics
Nie et al. Network traffic prediction based on deep belief network in wireless mesh backbone networks
Hu et al. Federated region-learning: An edge computing based framework for urban environment sensing
CN106570597B (zh) 一种sdn架构下基于深度学习的内容流行度预测方法
Yuan et al. Parameter extraction of solar cell models using chaotic asexual reproduction optimization
Jiang et al. Deepurbanmomentum: An online deep-learning system for short-term urban mobility prediction
CN113191484A (zh) 基于深度强化学习的联邦学习客户端智能选取方法及系统
CN109462520B (zh) 基于lstm模型的网络流量资源态势预测方法
CN112532746B (zh) 一种云边协同感知的方法及系统
CN114912705A (zh) 一种联邦学习中异质模型融合的优化方法
CN110267292B (zh) 基于三维卷积神经网络的蜂窝网络流量预测方法
JP2013074695A (ja) 太陽光発電予測装置、太陽光発電予測方法及び太陽光発電予測プログラム
Gao et al. Federated region-learning for environment sensing in edge computing system
CN112232543A (zh) 一种基于图卷积网络的多站点预测方法
CN115660147A (zh) 一种基于传播路径间与传播路径内影响力建模的信息传播预测方法及系统
Nguyen et al. Spatially-distributed federated learning of convolutional recurrent neural networks for air pollution prediction
CN115115021A (zh) 基于模型参数异步更新的个性化联邦学习方法
Al-Omary et al. Prediction of energy in solar powered wireless sensors using artificial neural network
CN117236421A (zh) 一种基于联邦知识蒸馏的大模型训练方法
CN115168654A (zh) 一种基于多属性决策的超网络重要节点识别方法
CN116205383B (zh) 一种基于元学习的静态动态协同图卷积交通预测方法
CN107018019B (zh) 一种基于复杂演化网络的航班延误传播特性分析方法
Jafarkazemi et al. Performance prediction of flat-plate solar collectors using MLP and ANFIS

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination