CN116562390A - 多聚合节点联邦知识蒸馏学习方法、系统、设备及介质 - Google Patents
多聚合节点联邦知识蒸馏学习方法、系统、设备及介质 Download PDFInfo
- Publication number
- CN116562390A CN116562390A CN202310324054.2A CN202310324054A CN116562390A CN 116562390 A CN116562390 A CN 116562390A CN 202310324054 A CN202310324054 A CN 202310324054A CN 116562390 A CN116562390 A CN 116562390A
- Authority
- CN
- China
- Prior art keywords
- aggregation node
- model
- aggregation
- probability
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000004220 aggregation Methods 0.000 title claims abstract description 233
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 48
- 238000000034 method Methods 0.000 title claims abstract description 47
- 230000002776 aggregation Effects 0.000 claims abstract description 200
- 238000012937 correction Methods 0.000 claims abstract description 56
- 238000012549 training Methods 0.000 claims abstract description 35
- 238000003860 storage Methods 0.000 claims abstract description 19
- 239000013598 vector Substances 0.000 claims description 39
- 230000006870 function Effects 0.000 claims description 23
- 238000004364 calculation method Methods 0.000 claims description 22
- 239000011159 matrix material Substances 0.000 claims description 15
- 238000009826 distribution Methods 0.000 claims description 12
- 238000004590 computer program Methods 0.000 claims description 3
- 238000004891 communication Methods 0.000 abstract description 18
- 230000002411 adverse Effects 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 230000008569 process Effects 0.000 description 4
- 230000004044 response Effects 0.000 description 4
- 230000009286 beneficial effect Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000003068 static effect Effects 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 239000012141 concentrate Substances 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000000873 masking effect Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computing Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供了多聚合节点联邦知识蒸馏学习方法、系统、设备及介质,方法包括对客户端的预测概率进行修正,根据计算得到的修正概率获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数;以各聚合节点的公共数据集的数据量为权重,计算聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数的加权平均数,得到聚合节点的更新模型参数和聚合节点更新模型;根据聚合节点更新模型,得到聚合节点更新模型在公共数据集上的预测概率并进行修正;基于知识蒸馏,根据计算得到的修正概率,训练与聚合节点连接的客户端的个性化模型。通过多个聚合节点分别建立公共数据集,共同承担通信压力和存储压力。
Description
技术领域
本申请涉及机器学习技术领域,尤其涉及多聚合节点联邦知识蒸馏学习方法、系统、设备及介质。
背景技术
联邦知识蒸馏学习的模型准确度依赖于中心节点的公共数据集。相关技术中,没有考虑到公共数据集对中心节点通信及存储的压力,因此,中心节点的通信能力的存储能力限制了公共数据集的大小,并限制了模型收敛速度,影响模型准确度。
发明内容
有鉴于此,本申请的目的在于提出多聚合节点联邦知识蒸馏学习方法、系统、设备及介质。
基于上述目的,本申请提供了一种多聚合节点联邦知识蒸馏学习方法,包括:
获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数;
根据聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到聚合节点的更新模型参数和聚合节点更新模型;
基于知识蒸馏,根据聚合节点更新模型,训练与聚合节点连接的客户端的个性化模型。
本申请还提供了一种多聚合节点联邦知识蒸馏学习系统,包括:
聚合节点模型参数获取模块,用于获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数;
聚合节点更新模型获取模块,用于根据聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到聚合节点的更新模型参数和聚合节点更新模型;
知识蒸馏模块,用于基于知识蒸馏,根据聚合节点更新模型,训练与聚合节点连接的客户端的个性化模型。
本申请还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行程序时实现上述的方法。
本申请还提供了一种非暂态计算机可读存储介质,非暂态计算机可读存储介质存储计算机指令,计算机指令用于使计算机执行上述方法。
从上面所述可以看出,本申请提供的多聚合节点联邦知识蒸馏学习方法、系统、设备及介质,通过多个聚合节点分别建立公共数据集,共同承担通信压力和存储压力。降低对公共数据集大小的限制,从而提高模型收敛速度,提高模型准确度。
附图说明
为了更清楚地说明本申请中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例的多聚合节点联邦知识蒸馏学习方法的流程示意图。
图2为本申请实施例的多聚合节点联邦知识蒸馏学习系统的结构示意图。
图3为本申请实施例的服务器的硬件结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚明白,以下结合具体实施例,并参照附图,对本申请进一步详细说明。
需要说明的是,除非另外定义,本申请实施例使用的技术术语或者科学术语应当为本申请所属领域内具有一般技能的人士所理解的通常意义。“包括”或者“包含”等类似的词语意指出现该词前面的元件或者物件涵盖出现在该词后面列举的元件或者物件及其等同,而不排除其他元件或者物件。“连接”或者“相连”等类似的词语并非限定于物理的或者机械的连接,而是可以包括电性的连接,不管是直接的还是间接的。
联邦学习是一种将数据和模型解耦合的分布式框架,可以解决数据孤岛和隐私保护难题,整个过程不需要将参与方数据集中到一个中心存储点,在数据不离开本地的情况下,实现各参与方的联合建模。在联邦学习框架中,一般存在一个中心计算方,承担收集其他各方传递的模型参数信息,并经过相应算法更新,向各方返回任务,反复迭代直至收敛,最终构建一个有效的全局模型。整个过程客户端和服务器都无权获取和控制其他客户端的数据,构建联邦学习模型期间不影响客户端设备的正常使用,训练好的联邦学习模型可以在各数据参与方之间共享和部署,在智慧医疗、金融保险和智能物联网等领域有广泛的应用前景。
联邦知识蒸馏学习的模型准确度依赖于中心节点的公共数据集。相关技术中,没有考虑到公共数据集对中心节点通信及存储的压力,因此,中心节点的通信能力的存储能力限制了公共数据集的大小,并限制了模型收敛速度,影响模型准确度。
基于相关技术上述的缺陷,本申请实施例提供了多聚合节点联邦知识蒸馏学习方法、系统、设备及介质。
本申请提供的多聚合节点联邦知识蒸馏学习方法、系统、设备及介质,通过多个聚合节点分别建立公共数据集,共同承担通信压力和存储压力。降低对公共数据集大小的限制,从而提高模型收敛速度,提高模型准确度。
图1示出了本申请实施例所提供的多聚合节点联邦知识蒸馏学习方法的流程示意图。
如图1所示,本申请实施例提供了一种多聚合节点联邦知识蒸馏学习方法,包括:
步骤S101、获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数。
在本实施例中,所有聚合节点具有相同的模型结构。这样,所有聚合节点的模型参数的结构也相同,各聚合节点之间可以传输模型参数来进行参数优化。
作为一个可选的实施例,联邦知识蒸馏学习方法包括通过如下方法得到聚合节点的模型参数:
将客户端的数据划分为公共数据和私有数据。
在本实施例中,预先设置公共数据与私有数据的比例,客户端可以随机或人工选择公共数据。
根据与聚合节点连接的所有客户端的公共数据,得到公共数据集。
在本实施例中,客户端可以自行选择连接的聚合节点。在选择连接的聚合节点时,可以向客户端展示客户端与聚合节点之间的连接参数,客户端可以根据连接参数选择想要连接的聚合节点。具体实施时,连接参数可以包括丢包率、上下行速度和网络延迟。
将私有数据作为训练集训练客户端的初始模型,得到第一客户端模型。
在本实施例中,预先设置初始模型列表,初始模型列表包括多个初始模型。多个初始模型的模型参数具有不同的参数数量,因此,各个初始模型需要的设备算力不同,客户端可以根据自身的设备算力选择要使用的初始模型。
使用第一客户端模型对公共数据集进行预测,得到公共数据集中的每个数据对应的第一预测概率分布。
将每个数据对应的第一预测概率分布表示为行向量,再将所有数据对应的行向量作为列向量中的元素,得到第一预测概率。
对第一预测概率进行修正计算,得到第一修正概率。
对与聚合节点连接的所有客户端的第一修正概率进行聚合计算,得到聚合概率。
在本实施例中,通过对第一预测概率进行修正计算来提高概率的准确度,得到准确度更高的第一修正概率,并通过计算所有客户端的第一修正概率的平均数,得到聚合概率。这样,得到的聚合概率融合了与聚合节点连接的所有客户端的第一修正概率,能够用于训练聚合节点模型,并使根据聚合概率训练得到的聚合节点模型具有第一客户端模型的信息。
根据聚合概率设置损失函数,将公共数据集作为训练集训练聚合节点的初始模型,得到聚合节点模型。
在本实施例中,损失函数可以为下式:
其中,i为多聚合节点中聚合节点的序号,Lossi为第i个聚合节点的损失函数,λ为参数,该参数为超参数,在训练过程中根据损失函数的收敛情况进行调优,Nlabel为类别数,l为由标签类别组成的1×Nlabel的向量,pi(Xi,l,wi)为第i个聚合节点的聚合节点模型对公共数据集进行预测得到的预测概率,Xi为公共数据集中的数据的合集,Yi为数据对应的标签的合集,wi为第i个聚合节点的聚合节点模型的模型参数,Pavg(Xi,l)为聚合概率,KL(Pavg(Xi,l)||pi(Xi,l,wi)为计算聚合概率与pi(Xi,l,wi)之间的KL散度。
在使用上述损失函数训练聚合节点模型时,应用了知识蒸馏。知识蒸馏(knowledge distillation)是模型压缩的一种常用的方法,通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息,来训练这个小模型,以期达到更好的性能和精度。这个大模型称之为Teacher(教师模型),小模型称之为Student(学生模型)。来自Teacher模型输出的监督信息称之为knowledge(知识),而student学习迁移来自teacher的监督信息的过程称之为Distillation(蒸馏)。
这样,通过在损失函数中设置聚合概率,使损失函数收敛时,pi(Xi,l,wi)接近聚合概率,蒸馏聚合概率中的知识,从而使训练得到的聚合节点模型接近与该聚合节点连接的所有客户端的第一客户端模型。
根据聚合节点模型,得到聚合节点的模型参数。
在本实施例中,提取损失函数收敛时的聚合节点模型的模型参数作为聚合节点的模型参数。
步骤S102、根据聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到聚合节点的更新模型参数和聚合节点更新模型。
这样,更新模型参数融合了聚合节点的模型参数和相邻聚合节点的模型参数,提高模型参数的拟合度,进而能够提高聚合节点模型的性能。
作为一个可选的实施例,步骤S102可以包括:
计算聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数的加权平均数,得到聚合节点的更新模型参数;其中,各个模型参数的权重为对应的聚合节点的公共数据集的数据量。
在本实施例中,可以在计算更新模型参数前,建立拓扑矩阵来描绘多个聚合节点之间的连接关系。拓扑矩阵的横排数与纵列数均对应聚合节点的序号,若两个聚合节点相连,则两个聚合节点在拓扑矩阵中的交叉位置设为1,若两个聚合节点不相连,则两个聚合节点在拓扑矩阵中的交叉位置设为0,聚合节点自身在拓扑矩阵中的交叉位置设为1,即拓扑矩阵中的主对角线全部设为1。例如,聚合节点1与聚合节点2相连,在拓扑矩阵T中,T12=T21=1;聚合节点1与聚合节点3不相连,在拓扑矩阵T中,T13=T31=0;并且,T11=T22=……=TMM=1。
在计算更新模型参数时,可以使用如下的公式:
其中,wi′为第i个聚合节点的更新模型参数,M为聚合节点的总数,Nm为第m个聚合节点的公共数据集数据量,tim为拓扑矩阵中的数值,代表第i个聚合节点与第m个聚合节点的连接关系,wm为第m个聚合节点的模型参数。
这样,在计算更新模型参数时,只需要使用上述通用公式和拓扑矩阵即可计算得到所有聚合节点的更新模型参数,无需对每个聚合节点进行分别计算,简化计算流程,提高效率。
步骤S103、基于知识蒸馏,根据聚合节点更新模型,训练与聚合节点连接的客户端的个性化模型。
作为一个可选的实施例,步骤S103可以包括:
使用聚合节点更新模型对公共数据集进行预测,得到公共数据集中的每个数据对应的第二预测概率分布。
将每个数据对应的第二预测概率分布表示为行向量,再将所有数据对应的行向量作为列向量中的元素,得到第二预测概率。
对第二预测概率进行修正计算,得到第二修正概率。
根据第二修正概率设置损失函数,将公共数据集作为训练集训练第一客户端模型,得到第二客户端模型。
在本实施例中,损失函数可以为下式:
Lossij=λ(-Yilog(pij(Xi,l,wij)))+(1-λ)KL(Pi(Xi,k,wi)||pij(Xi,l,wij))
其中,j为与第i个聚合节点连接的客户端的序号,Lossij为第i个聚合节点连接的第j个客户端的损失函数,Nlabel为类别数,l为由标签类别组成的1×Nlabel的向量,pij(Xi,l,wij)为第j个客户端的第二客户端模型对公共数据集进行预测得到的预测概率,wij为第j个客户端的第二客户端模型的模型参数,Pi(Xi,l,wi)为第二修正概率,KL(Pi(Xi,l,wi)||pij(Xi,l,wij)为计算第二修正概率与pij(Xi,l,wij)之间的KL散度。
这样,通过在损失函数中设置第二修正概率,使损失函数收敛时,pij(Xi,l,wij)接近第二修正概率,蒸馏第二修正概率中的知识,从而使训练得到的第二客户端模型接近与该客户端连接的聚合节点的聚合节点更新模型,使第二客户端模型学习到多个客户端的公共数据的知识。
将私有数据作为训练集训练第二客户端模型,得到客户端的个性化模型。
在本实施例中,损失函数可以为下式:
Lossij′=-Yijlog(pij(Xij,l,wij′))
其中,Lossij′为第i个聚合节点连接的第j个客户端的损失函数,Nlabel为类别数,l为由标签类别组成的1×Nlabel的向量,pij(Xij,l,wij′)为第j个客户端的个性化模型对私有数据进行预测得到的预测概率,Xij为私有数据中的数据的合集,wij′为第j个客户端的个性化模型的模型参数。
这样,使用私有数据对第二客户端模型进行训练,得到更契合客户端的数据的个性化模型。
考虑到当第一客户端模型对公共数据集进行预测出现失败时,失败预测对应的概率向量会对聚合节点训练模型造成不良影响,需要对失败预测对应的概率向量进行修正计算。
作为一个可选的实施例,使用第一客户端模型对公共数据集进行预测,得到第一预测概率;对第一预测概率进行修正计算,得到第一修正概率,包括:
使用第一客户端模型对公共数据集进行预测,得到第一模型输出和第一预测概率。
比较第一模型输出与公共数据集。
在本实施例中,比较第一模型输出中的每个输出与公共数据集中的对应标签。
响应于确定公共数据集中的一个数据的标签与该数据在第一模型输出中的标签不匹配,使用数据的独热编码修正第一预测概率中的对应行向量,得到第一修正概率。
独热编码(One-Hot Encoding),又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。即,只有一位是1,其余都是零值。
具体实施时,可以通过如下公式得到第一修正概率:
Pij(Xi,l,w′ij)=ij(Xi,l,w′ij)*maskij+onehot(Yi)*(I-maskij),
其中,Nlabel为类别数,l为由标签类别组成的1×Nlabel的向量,Pij(Xi,l,wij′)为第一修正概率,pij(Xi,l,wij′)为第一预测概率,wij′为第一客户端模型的模型参数,I为由单位矩阵组成的向量,Ik=E(Nlabel),maskij为掩膜,onehot(Yi)为独热编码,/>为第k个数据在maskij中的第一掩膜数值,/>为第k个数据在第一模型输出中的输出,/>为公共数据集中的第k个数据,/>为第k个数据对应的标签。
这样,通过使用独热编码替换预测失败的数据的概率向量,避免对聚合节点训练模型造成不良影响。
考虑到当聚合节点更新模型对公共数据集进行预测出现失败时,失败预测对应的概率向量会对第二客户端模型造成不良影响,需要对失败预测对应的概率向量进行修正计算。
作为一个可选的实施例,使用聚合节点更新模型对公共数据集进行预测,得到第二预测概率;对第二预测概率进行修正计算,得到第二修正概率,包括:
使用聚合节点更新模型对公共数据集进行预测,得到第二模型输出和第二预测概率。
比较第二模型输出与公共数据集。
在本实施例中,比较第二模型输出中的每个输出与公共数据集中的对应标签。
响应于确定公共数据集中的一个数据的标签与该数据在第二模型输出中的标签不匹配,使用数据的独热编码修正第二预测概率中的对应行向量,得到第二修正概率。
具体实施时,可以通过如下公式得到第二修正概率:
Pi(Xi,l,w′i)=i(Xi,l,w′i)*maski+onehot(Yi)*(I-maski),
其中,Nlabel为类别数,l为由标签类别组成的1×Nlabel的向量,Pi(Xi,l,wi′)为第二修正概率,pij(Xi,l,wi′)为第二预测概率,wi′为更新模型参数,I为由单位矩阵组成的向量,Ik=E(Nlabel),maski为掩膜,为第k个数据在maski中的第二掩膜数值,/>为第k个数据在第二模型输出中的输出。
这样,通过使用独热编码替换预测失败的数据的概率向量,避免对第二客户端模型造成不良影响。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种多聚合节点联邦知识蒸馏学习系统。
图2示出了本申请实施例的多聚合节点联邦知识蒸馏学习系统的示意图。
参考图2,多聚合节点联邦知识蒸馏学习系统包括:
聚合节点模型参数获取模块,用于获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数。
聚合节点更新模型获取模块,用于根据聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到聚合节点的更新模型参数和聚合节点更新模型。
具体用于计算聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数的加权平均数,得到聚合节点的更新模型参数。其中,各个模型参数的权重为对应的聚合节点的公共数据集的数据量。
知识蒸馏模块,用于基于知识蒸馏,根据聚合节点更新模型,训练与聚合节点连接的客户端的个性化模型。
具体用于使用聚合节点更新模型对公共数据集进行预测,得到公共数据集中的每个数据对应的第二预测概率分布。将每个数据对应的第二预测概率分布表示为行向量,再将所有数据对应的行向量作为列向量中的元素,得到第二预测概率。对第二预测概率进行修正计算,得到第二修正概率。根据第二修正概率设置损失函数,将公共数据集作为训练集训练第一客户端模型,得到第二客户端模型。将私有数据作为训练集训练第二客户端模型,得到客户端的个性化模型。
作为一个可选的实施例,多聚合节点联邦知识蒸馏学习系统还包括:
聚合节点模型参数获取模块,用于将客户端的数据划分为公共数据和私有数据。根据与聚合节点连接的所有客户端的公共数据,得到公共数据集。将私有数据作为训练集训练客户端的初始模型,得到第一客户端模型。使用第一客户端模型对公共数据集进行预测,得到公共数据集中的每个数据对应的第一预测概率分布。将每个数据对应的第一预测概率分布表示为行向量,再将所有数据对应的行向量作为列向量中的元素,得到第一预测概率。对第一预测概率进行修正计算,得到第一修正概率。对与聚合节点连接的所有客户端的第一修正概率进行聚合计算,得到聚合概率。根据聚合概率设置损失函数,将公共数据集作为训练集训练聚合节点的初始模型,得到聚合节点模型及聚合节点的模型参数。
第一修正概率获取模块,用于使用第一客户端模型对公共数据集进行预测,得到第一模型输出和第一预测概率。比较第一模型输出与公共数据集。响应于确定公共数据集中的一个数据与该数据在第一模型的输出不匹配,使用数据的独热编码修正第一预测概率中的对应行向量,得到第一修正概率。
聚合概率获取模块,用于计算所有客户端的第一修正概率的平均数,得到聚合概率。
第二修正概率获取模块,用于使用聚合节点更新模型对公共数据集进行预测,得到第二模型输出和第二预测概率。比较第二模型输出与公共数据集。响应于确定公共数据集中的一个数据与该数据在第二模型的输出不匹配,使用数据的独热编码修正第二预测概率中的对应行向量,得到第二修正概率。
为了描述的方便,描述以上系统时以功能分为各种模块分别描述。当然,在实施本公开时可以把各模块的功能在同一个或多个软件和/或硬件中实现。
上述实施例的系统用于实现前述任一实施例中相应的多聚合节点联邦知识蒸馏学习方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述任意一实施例所述的多聚合节点联邦知识蒸馏学习方法。
图3示出了本实施例所提供的一种更为具体的服务器硬件结构示意图,该服务器可以包括:处理器1010、存储器1020、输入/输出接口1030、通信接口1040和总线1050。其中处理器1010、存储器1020、输入/输出接口1030和通信接口1040通过总线1050实现彼此之间在设备内部的通信连接。
处理器1010可以采用通用的CPU(Central Processing Unit,中央处理器)、微处理器、应用专用集成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本说明书实施例所提供的技术方案。
存储器1020可以采用ROM(Read Only Memory,只读存储器)、RAM(Random AccessMemory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。
输入/输出接口1030用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在服务器中(图中未示出),也可以外接于服务器以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。
通信接口1040用于连接通信模块(图中未示出),以实现本服务器与其他设备的通信交互。其中通信模块可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信。
总线1050包括一通路,在服务器的各个组件(例如处理器1010、存储器1020、输入/输出接口1030和通信接口1040)之间传输信息。
需要说明的是,尽管上述电子设备仅示出了处理器1010、存储器1020、输入/输出接口1030、通信接口1040以及总线1050,但是在具体实施过程中,该电子设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述电子设备中也可以仅包含实现本说明书实施例方案所必需的组件,而不必包含图中所示的全部组件。
上述实施例的电子设备用于实现前述任一实施例中相应的多聚合节点联邦知识蒸馏学习方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行如上任一实施例所述的多聚合节点联邦知识蒸馏学习方法。
本实施例的计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。
上述实施例的存储介质存储的计算机指令用于使所述计算机执行如上任一实施例所述的多聚合节点联邦知识蒸馏学习方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
所属领域的普通技术人员应当理解:以上任何实施例的讨论仅为示例性的,并非旨在暗示本申请的范围(包括权利要求)被限于这些例子;在本申请的思路下,以上实施例或者不同实施例中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上所述的本申请实施例的不同方面的许多其它变化,为了简明它们没有在细节中提供。
尽管已经结合了本申请的具体实施例对本申请进行了描述,但是根据前面的描述,这些实施例的很多替换、修改和变型对本领域普通技术人员来说将是显而易见的。
本申请实施例旨在涵盖落入所附权利要求的宽泛范围之内的所有这样的替换、修改和变型。因此,凡在本申请实施例的精神和原则之内,所做的任何省略、修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (10)
1.一种多聚合节点联邦知识蒸馏学习方法,其特征在于,包括:
获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数;
根据所述聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到所述聚合节点的更新模型参数和聚合节点更新模型;
基于知识蒸馏,根据所述聚合节点更新模型,训练与所述聚合节点连接的客户端的个性化模型。
2.根据权利要求1所述的多聚合节点联邦知识蒸馏学习方法,其特征在于,所述方法还包括通过如下方法得到所述聚合节点的模型参数:
将所述客户端的数据划分为公共数据和私有数据;
根据与所述聚合节点连接的所有客户端的公共数据,得到所述公共数据集;
将所述私有数据作为训练集训练所述客户端的初始模型,得到第一客户端模型;
使用所述第一客户端模型对所述公共数据集进行预测,得到所述公共数据集中的每个数据对应的第一预测概率分布;
将每个数据对应的第一预测概率分布表示为行向量,再将所有数据对应的行向量作为列向量中的元素,得到第一预测概率;
对所述第一预测概率进行修正计算,得到第一修正概率;
对与所述聚合节点连接的所有客户端的第一修正概率进行聚合计算,得到聚合概率;
根据所述聚合概率设置损失函数,将所述公共数据集作为训练集训练所述聚合节点的初始模型,得到聚合节点模型及所述聚合节点的模型参数。
3.根据权利要求1所述的多聚合节点联邦知识蒸馏学习方法,其特征在于,所述根据所述聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到所述聚合节点的更新模型参数,包括:
计算所述聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数的加权平均数,得到所述聚合节点的更新模型参数,计算公式如下:
其中,wi′为第i个聚合节点的所述更新模型参数,M为聚合节点的总数,Nm为第m个聚合节点的公共数据集数据量,tim为聚合节点连接关系矩阵中第i行第m列的元素,代表第i个聚合节点与第m个聚合节点的连接关系,wm为第m个聚合节点的模型参数。
4.根据权利要求2所述的多聚合节点联邦知识蒸馏学习方法,其特征在于,所述基于知识蒸馏,根据所述聚合节点更新模型,训练与所述聚合节点连接的客户端的个性化模型,包括:
使用所述聚合节点更新模型对所述公共数据集进行预测,得到所述公共数据集中的每个数据对应的第二预测概率分布;
将每个数据对应的第二预测概率分布表示为行向量,再将所有数据对应的行向量作为列向量中的元素,得到第二预测概率;
对所述第二预测概率进行修正计算,得到第二修正概率;
根据所述第二修正概率设置损失函数,将所述公共数据集作为训练集训练所述第一客户端模型,得到第二客户端模型;
将所述私有数据作为训练集训练所述第二客户端模型,得到所述客户端的个性化模型。
5.根据权利要求2所述的多聚合节点联邦知识蒸馏学习方法,其特征在于,所述使用所述第一客户端模型对所述公共数据集进行预测,得到第一预测概率;对所述第一预测概率进行修正计算,得到第一修正概率,包括:
使用所述第一客户端模型对所述公共数据集进行预测,得到第一模型输出和第一预测概率;
比较所述第一模型输出与所述公共数据集;响应于确定所述公共数据集中的第k个数据标签/>与该数据在所述第一模型输出/>中的标签不匹配,使用所述数据的独热编码/>修正所述第一预测概率中的对应行向量,得到所述第一修正概率,公式统一表示如下:
其中,为所述第一修正概率,
,为数据/>在所述第一修正概率中标签labelv对应的修正概率值,/>为所述第一预测概率,/>
,为数据/>在第一预测概率中标签labelv对应的概率值,i为聚合节点的序号,j为与第i个聚合节点连接的客户端的序号,Nlabel为类别数,l为由类别组成的1×Nlabel的向量,/>wij′为所述第一客户端模型的模型参数,/>为第i个聚合节点下第j个客户端在所述公共数据集中的第k个数据的第一掩膜数值,所述第一掩膜数值在所述数据与该数据在所述第一模型的输出匹配时为Nlabel×Nlabel的单位矩阵E(Nlabel),否则为Nlabel×Nlabel的零矩阵zeros(Nlabel)。
6.根据权利要求2所述的多聚合节点联邦知识蒸馏学习方法,其特征在于,所述对与所述聚合节点连接的所有客户端的第一修正概率进行聚合计算,得到聚合概率,包括:
计算所述所有客户端的第一修正概率的平均数,得到所述聚合概率。
7.根据权利要求4所述的多聚合节点联邦知识蒸馏学习方法,其特征在于,所述使用所述聚合节点更新模型对所述公共数据集进行预测,得到第二预测概率;对所述第二预测概率进行修正计算,得到第二修正概率,包括:
使用所述聚合节点更新模型对所述公共数据集进行预测,得到第二模型输出和第二预测概率;
比较所述第二模型输出与所述公共数据集;
响应于确定所述公共数据集中的第k个数据标签/>与该数据在所述第二模型输出/>中的标签不匹配,使用所述数据的独热编码/>修正所述第二预测概率中的对应行向量,得到所述第二修正概率,公式统一表示如下:
其中,为所述第二修正概率,
,为数据/>在所述第二修正概率中标签labelv对应的修正概率值,/>为所述第二预测概率,
,为数据/>在第二预测概率中标签labelv对应的概率值,i为多聚合节点中聚合节点的序号,Nlabel为类别数,l为由标签类别组成的1×Nlabel的向量,wi′为所述更新模型参数,/>为第i个聚合节点在所述公共数据集中的第k个数据的第二掩膜数值,所述第二掩膜数值在所述数据与该数据在所述第二模型的输出匹配时为Nlabel×Nlabel的单位矩阵E(Nlabel),否则为Nlabel×Nlabel的零矩阵zeros(Nlabel)。
8.一种多聚合节点联邦知识蒸馏学习系统,其特征在于,包括:
聚合节点模型参数获取模块,用于获取聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数;
聚合节点更新模型获取模块,用于根据所述聚合节点的模型参数和该聚合节点的相邻聚合节点的模型参数,得到所述聚合节点的更新模型参数和聚合节点更新模型;
知识蒸馏模块,用于基于知识蒸馏,根据所述聚合节点更新模型,训练与所述聚合节点连接的客户端的个性化模型。
9.一种电子设备,其特征在于,包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述程序时实现如权利要求1至7中任意一项所述的方法。
10.一种非暂态计算机可读存储介质,其特征在于,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使计算机执行如权利要求1至7中任意一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310324054.2A CN116562390A (zh) | 2023-03-29 | 2023-03-29 | 多聚合节点联邦知识蒸馏学习方法、系统、设备及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310324054.2A CN116562390A (zh) | 2023-03-29 | 2023-03-29 | 多聚合节点联邦知识蒸馏学习方法、系统、设备及介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116562390A true CN116562390A (zh) | 2023-08-08 |
Family
ID=87493624
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310324054.2A Pending CN116562390A (zh) | 2023-03-29 | 2023-03-29 | 多聚合节点联邦知识蒸馏学习方法、系统、设备及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116562390A (zh) |
-
2023
- 2023-03-29 CN CN202310324054.2A patent/CN116562390A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20220391771A1 (en) | Method, apparatus, and computer device and storage medium for distributed training of machine learning model | |
CN110263921B (zh) | 一种联邦学习模型的训练方法及装置 | |
WO2020094060A1 (zh) | 推荐方法及装置 | |
CN110009486B (zh) | 一种欺诈检测的方法、系统、设备及计算机可读存储介质 | |
US11341415B2 (en) | Method and apparatus for compressing neural network | |
CN112232165B (zh) | 一种数据处理方法、装置、计算机及可读存储介质 | |
CN113572697A (zh) | 一种基于图卷积神经网络与深度强化学习的负载均衡方法 | |
CN113826117A (zh) | 来自神经网络的高效二元表示 | |
CN109214519B (zh) | 数据处理系统、方法和设备 | |
CN111935005B (zh) | 数据传输方法、装置、处理设备及介质 | |
CN113904915A (zh) | 一种基于物联网的电力通信智能故障分析方法及系统 | |
CN112988851A (zh) | 反事实预测模型数据处理方法、装置、设备及存储介质 | |
CN116797850A (zh) | 基于知识蒸馏和一致性正则化的类增量图像分类方法 | |
CN116738983A (zh) | 模型进行金融领域任务处理的词嵌入方法、装置、设备 | |
CN116562390A (zh) | 多聚合节点联邦知识蒸馏学习方法、系统、设备及介质 | |
CN113642654B (zh) | 图像特征的融合方法、装置、电子设备和存储介质 | |
CN114912627A (zh) | 推荐模型训练方法、系统、计算机设备及存储介质 | |
CN112784967B (zh) | 信息处理方法、装置以及电子设备 | |
KR102258206B1 (ko) | 이종 데이터 융합을 이용한 이상 강수 감지 학습 장치, 이상 강수 감지 학습 방법, 이종 데이터 융합을 이용한 이상 강수 감지 장치 및 이상 강수 감지 방법 | |
WO2021115269A1 (zh) | 用户集群的预测方法、装置、计算机设备和存储介质 | |
CN113868523A (zh) | 推荐模型训练方法、电子设备及存储介质 | |
CN110147804B (zh) | 一种不平衡数据处理方法、终端及计算机可读存储介质 | |
CN114581946B (zh) | 人群计数方法、装置、存储介质及电子设备 | |
CN116958149B (zh) | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 | |
CN112446464B (zh) | 一种神经网络卷积运算方法、装置以及相关产品 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |