CN115511109A - 一种高泛化性的个性化联邦学习实现方法 - Google Patents

一种高泛化性的个性化联邦学习实现方法 Download PDF

Info

Publication number
CN115511109A
CN115511109A CN202211206093.4A CN202211206093A CN115511109A CN 115511109 A CN115511109 A CN 115511109A CN 202211206093 A CN202211206093 A CN 202211206093A CN 115511109 A CN115511109 A CN 115511109A
Authority
CN
China
Prior art keywords
model
client
training
local
global
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211206093.4A
Other languages
English (en)
Inventor
王建新
刘渊
沈成超
盛韬
王殊
段桂华
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hunan Huaxin Software Co ltd
Original Assignee
Central South University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Central South University filed Critical Central South University
Priority to CN202211206093.4A priority Critical patent/CN115511109A/zh
Publication of CN115511109A publication Critical patent/CN115511109A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Machine Translation (AREA)

Abstract

本发明公开了一种高泛化性的个性化联邦学习实现方法,包括,服务端随机初始化全局双分支模型并发送初始化参数至客户端;客户端初始化本地双分支模型并利用本地数据进行本地迭代训练得到更新的客户端本地模型;将更新后的客户端本地模型训练的统计参数和全局任务分支的模型参数上传至服务端;服务端聚合所有客户端的全局任务分支的模型参数并更新发送给多个客户端;客户端根据服务端更新的全局任务分支模型参数并结合本地迭代训练得到的个性化任务分支模型参数,构成更新的客户端本地双分支模型;客户端使用本地双分支模型基于本地数据迭代训练并循环参与联邦更新直至满足预设标准。可在保证个性化联邦学习有效性的同时提升模型的泛化性。

Description

一种高泛化性的个性化联邦学习实现方法
技术领域
本发明涉及联邦学习技术领域,尤其涉及一种高泛化性的个性化联邦学习实现方法。
背景技术
联邦学习是指多个相互隔离的孤岛数据集上训练模型的任务,在愈加严格的隐私政策的要求下,传统中心式汇聚多个数据孤岛的数据来进行数据挖掘的方式变得不可行,而单个数据孤岛的有效数据不足,数据驱动的建模和数据挖掘受到限制,此时联邦学习便能发挥作用。通用联邦学习是指,所有客户端在不共享数据的情况下,共同训练一个共识模型,以尽可能地学到来自多个客户端数据的知识。通用联邦学习步骤主要包括:客户端选择、模型分发、模型训练和模型聚合,通过迭代直至收敛得到一个聚合的共识模型。
由于联邦学习数据隔离的固有属性,客户端的数据分布不可知,不同客户端模型的学习存在很强的异质性,如通过来自不同地理环境的客户端解决不同的任务客户端,但是此时聚合的共识模型偏向某些客户端从而整体表现不佳。为了处理客户端之间的这种异质性,个性化联邦学习允许每个客户端保留并优化独立的个性化模型,而不是使用全局的共识模型。旨在客户端从联邦学习中获得收益的同时,在本地可见的数据上有更好的表现,即个性化模型的表现优于客户端孤岛式独自训练产生的模型,同时优于联邦共识模型。
虽然个性化联邦学习方法为联邦客户端的异质性困境提供了解决方案,但是主流的个性化联邦学习实现方法侧重于在可见数据的性能提升。由于对可见数据的进一步优化,大多数主流方法生成的个性化模型容易过拟合,最终导致较强模型偏向性和模型泛化性降低。然而,模型泛化性是现实场景中需要关注的问题,例如,医院客户端接收来自未知医院的转诊患者的数据,不仅能关注联邦模型在本地可见数据的表现,还可以关注其在未知分布数据上的性能。
因此,亟需一种可侧重于模型的泛化性的个性化联邦学习实现方法,在保证个性化联邦学习有效性的同时,还可以提升模型的泛化性。
发明内容
针对背景技术中的问题,本发明提供了一种高泛化性的个性化联邦学习实现方法,利用任务独立的个性化批归一化和全局批归一化特征,通过双分支结构同时学习模型的个性化能力和泛化能力,即不仅能有效地提升客户端本地模型面对未知数据的泛化能力,还能保证客户端本地模型在客户端本地数据分布下的个性化能力。
第一方面,本发明提供了一种高泛化性的个性化联邦学习实现方法,包括,
步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端;
步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端;
步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;
步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。
进一步地,服务端使用的全局模型和客服端使用的客户端本地模型结构相同,即模型的特征提取层后添加批归一化层;其中,特征提取层为任务共享层,批归一化层为任务特定层;任务特定层包括全局批归一化层和个性化批归一化层。
进一步地,全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成。
进一步地,所述统计参数包括客户端参与训练的数据量。
优选地,步骤2中本地迭代训练得到更新后的客户端本地模型的过程具体为:
将本地数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出yg和个性化任务输出yl,通过计算交叉熵损失分别得到全局任务损失lossg和个性化任务损失lossl
交叉熵损失的表达式如下:
Figure BDA0003873865290000021
其中,a取g或l;yj为预测目标,
Figure BDA0003873865290000022
是实际预测结果;m表示参与训练的客户端的数量;
利用全局任务损失lossg和个性化任务损失lossl得到总体损失lossoverall,表示式为:
lossoverall=αlossg+(1-α)lossl
其中,α为损失比例系数;
结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数,客户端本地模型的模型参数更新表达式如下:
Figure BDA0003873865290000031
其中,gl表示个性化任务子模型优化得到的一次迭代的总体梯度;gg分别表示全局任务子模型优化得到的一次迭代的总体梯度;
Figure BDA0003873865290000032
wg表示全局任务子模型的模型参数;wl表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;i表示第i个客户端。
优选地,步骤3中通过加权平均计算得到聚合后新的全局任务子模型的模型参数具体为:
计算客户端参与训练的数据量占所有客户端参与训练数据总量的比重;
全局任务子模型的模型参数wg的更新公式如下:
Figure BDA0003873865290000033
其中,K表示参与训练的客户端总数;k表示第k个客户端;n表示所有客户端参与训练的数据总量;nk表示第k个客户端训练的数据量;
Figure BDA0003873865290000034
表示第k个客户端在第t轮联邦训练中的全局任务子模型的模型参数;wg,t+1表示第k个客户端在第t+1轮联邦训练中的全局任务子模型的模型参数。
优选地,步骤5中预设标准具体为:
根据损失曲线对数据和客户端分布进行判断:
若为稳定收敛的数据和客户端分布时,客户端本地模型经过预设轮次的联邦训练后,将最后一个轮次的模型参数作为训练结果;
若为不能稳定收敛的数据和客户端分布时,通过将数据集中的验证集添加到联邦训练中,选取预设联邦训练轮次内在验证集中表现最优的客户端本地模型的模型参数作为训练结果。
进一步地,在训练过程中:
通过计算训练集的多任务损失以及反向传播即可更新模型参数,若需要对模型进行预测推理步骤,比如计算训练集、验证集和测试集的准确度,需要对输入数据的类别进行推理时,在推理阶段使用集成推理方法来得到客户端本地模型对输入数据的输出结果;
其中,集成推理方法具体如下:
本地数据输入到全局任务子模型,输出概率形式的全局任务输出yg
本地数据输入到个性化任务子模型,输出概率形式的个性化任务输出yp
比较上述两个子模型输出的所有类别对应的概率值,选择其中最大的概率值对应的类别作为客户端本地模型的分类结果,计算模型的准确率。
第二方面,本发明提供了一种高泛化性的个性化联邦学习实现方法,应用于服务端,包括:
Step1:服务端随机初始化双分支结构的全局模型,与参与训练的客户端生成连接,并将初始化模型参数发送给参与训练的客户端,等待客户端训练;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
Step2:接收所有参与训练的客户端上传的全局模型训练的统计参数、全局任务子模型的模型参数和客户端客户端本地模型的评价结果;
Step3:若联邦训练轮次或聚合的评价结果满足预设标准,则停止联邦训练;若联邦训练轮次和聚合的评价结果不满足预设标准,则将所上传的全局任务子模型参数进行加权平均计算得到聚合后的全局任务子模型的模型参数,并将聚合后的模型参数发送给参与训练的客户端,等待客户端训练,返回执行Step2,进行循环更新;所述聚合的评价结果是指服务端对各参与训练的客户端上传的客户端本地模型的评价结果进行聚合后的最终结果。
第三方面,一种高泛化性的个性化联邦学习实现方法,应用于客户端,包括:
S1:与服务端生成连接,接收服务端发送的初始化模型参数对客户端本地模型进行初始化;其中,客户端本地模型包括全局任务子模型分支和个性化任务子模型分支;
S2:利用本地数据对客户端本地模型进行进行一轮迭代训练,得到客户端本地模型的模型参数,将客户端本地模型中的统计参数、客户端本地模型是否满足预设标准的评价结果和全局子模型的模型参数上传至服务端;
S3:等待服务端发送结束训练指令,若指令为结束训练,进而结束训练并保存预设最佳的客户端本地模型;若指令为继续训练,则等待服务端对全局子模型的模型参数进行聚合,接收服务器发送的聚合后的全局子模型参数,更新客户端本地模型中的全局子模型参数,返回S2,进行循环更新。
有益效果
本发明提供了一种高泛化的个性化联邦学习实现方法,所述方法利用双分支结构的全局模型,通过全局任务子模型和个性化化任务子模型同时学习全局泛化任务和局部个性化任务,利用任务之间的相关性相互促进,有效提升了客户端本地模型对未知分布数据的性能表现,改善了客户端本地模型泛化性差的问题。
参与训练的客户端利用本地数据对客户端本地模型进行训练,将全局子模型参数上传服务端,服务端对所有参与客户端上传的全局任务子模型参数进行联邦聚合,有效的降低了联邦聚合对个性化特征学习的冲突,增强了客户端的客户端本地模型对个性化特征的学习,并且保留了全局特征,在未增加额外联邦通信轮次、局部训练轮次和训练模型的条件下,同时完成个性化特征和全局特征的学习,在提高客户端本地模型的泛化性的同时保证客户端本地模型的个性化性能。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明所述方法提供的双分支结构的全局模型结构图;
图2是本发明所述方法提供的服务端和客户端的通信示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面将对本发明的技术方案进行详细的描述。显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所得到的所有其它实施方式,都属于本发明所保护的范围。
本发明提供了一种高泛化的个性化联邦学习实现方法,用于解决个性化联邦学习在图像中未知分布的数据上表现差的问题,并同时关注联邦模型在本地图像数据的性能表现。本发明提供的技术方案适用于不同的神经网络模型,即在特征提取层后添加批归一化层,可根据不同的需求选取神经网络模型的类型。下面结合附图和具体实施例对本发明中技术方案作进一步详细的说明。
实施例1
如图1-2所示,本实施例提供了一种高泛化的个性化联邦学习方法,本实施例中选取神经网络中的卷积神经网络对图像进行分类任务为例,包括如下步骤:
步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支。
所有客户端使用相同的模型结构,且与服务端使用的模型结构相同。当服务端发送初始化模型参数时,使用相同模型结构的客户端承载服务端的初始化模型参数,可实现各个客户端具备相同的初始化状态。本发明的技术方案中模型结构采用双分支结构的全局模型,模型结构如图1所示。
图1中参与联邦训练的模型包含卷积层和全连接层和其他神经网络层等。由于数据隔离是联邦学习的固有属性,不同客户端的数据分布存在常见的几种非独立同分布的情景,如类别分布不均衡和特征分布偏移等。现实应用中,不同医疗机构使用不同的设备产生的影像,就可能产生特征分布偏移。为了适应不同客户端的特征分布情况,本发明在特征提取层后添加批归一化层,具体在图1中表现为在每个卷积层和全连接层后面(除最后一个用于分类的全连接层)添加批归一化层。其实现方式是:通过统计一个批次训练数据的均值和方差;将两个可训练的参数将输出进行归一化,使其为预设合理范围内的一个数值,最终实现对原始输入产生特定于某个批次数据分布的偏移。不同客户端所持有的数据分布存在差异,因此,批归一化层训练得到的参数也不同。本发明使用独立的批归一化层来实现个性化联邦模型。
由于个性化模型的低泛化性,本发明设计了全局任务以增强泛化性。针对不同的任务,本发明在原神经网络模型基础上,通过在特征提取层之后的相同位置设计两个用于不同任务的批归一化层,其中全局任务的批归一化层用于联邦训练中,个性化任务的批归一化层用于形成客户端的个性化模型参数。本实施例中,模型训练存在两个任务分支,分别对应一个子模型,子模型由任务共享层和任务特定层组成,任务共享层由两个分支共同优化,任务特定层则单独优化。全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成,由此构成了双分支结构的全局模型。超参数包括联邦训练轮次、本地迭代轮次、损失比例系数、学习率、数据批大小。
步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端。
客户端使用初始化后的客户端本地模型进行一轮联邦训练的本地迭代训练,其具体过程为:
客户端i使用本地私有的图像数据
Figure BDA0003873865290000061
其中,ni表示客户端i将ni个样本的集合作为训练数据;xj表示训练数据中第j个输入的本地图像数据;yj表示训练数据第j个本地图像数据的真实标签。在具体实施时,可根据实际需求对本地迭代轮次进行设置,本实例中选择1次本地迭代以避免局部模型过拟合局部数据。
使用wg表示全局任务子模型,wl表示个性化任务子模型。将本地图像数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出yg和个性化任务输出yl,通过计算交叉熵损失分别得到全局任务损失lossg和个性化任务损失lossl
交叉熵损失的表达式如下:
Figure BDA0003873865290000071
其中,a取g或l;yj为预测目标,
Figure BDA0003873865290000072
是实际预测结果;m表示参与训练的客户端的数量;利用全局任务损失lossg和个性化任务损失lossl得到总体损失lossoverakl,表示式为:
lossoverakl=αlossg+(1-α)lossl
其中,α为损失比例系数。
结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数。其中,在全局任务损失lossg计算梯度时,个性化任务特定的批归一化层的梯度为0;同理,在个性化任务损失lossl计算梯度时,全局任务特定的批归一化层梯度为0。也就是说,总体损失losspverall对模型的优化相当于同时优化两个子模型,任务共享层的参数由两个损失共同优化,任务特定层由两个损失单独优化。客户端本地模型的模型参数更新表达式如下:
Figure BDA0003873865290000073
其中,gl表示个性化任务子模型优化得到的一次迭代的总体梯度;gg分别表示全局任务子模型优化得到的一次迭代的总体梯度;
Figure BDA0003873865290000074
wg表示全局任务子模型的模型参数;wl表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;i表示第i个客户端。
步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端。
其中,通过加权平均计算得到聚合后新的全局任务子模型的模型参数具体为:
计算客户端参与训练的数据量占所有客户端参与训练数据总量的比重;
全局任务子模型的模型参数wg的更新公式如下:
Figure BDA0003873865290000075
其中,K表示参与训练的客户端总数;k表示第k个客户端;n表示所有客户端参与训练的数据总量;nk表示第k个客户端训练的数据量;
Figure BDA0003873865290000076
表示第k个客户端在第t轮联邦训练中的全局任务子模型的模型参数;wg,t+1表示第k个客户端在第t+1轮联邦训练中的全局任务子模型的模型参数。
值得注意的是,全局任务子模型的联邦聚合与本地双分支结构的模型的本地迭代训练过程是解耦合的,全局任务子模型用于学习数据的一致性知识,也可以根据实际需求采用其他的聚合方法改善一致性特征的学习。
步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;
其中,客户端客户端本地模型的全局子模型的模型参数来自步骤3中服务端聚合所有客户端的全局子模型的模型参数后更新的数据,包括任务共享层和全局任务特定的批归一化层;而个性化任务特定的批归一化层保持步骤2通过本地图像数据训练后的局部更新的个性化子模型的模型参数,不同客户端模型形成差异化,因此此步骤生成的客户端本地全局模型为个性化模型。
步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。
其中,预设标准具体为:
根据损失曲线对数据和客户端分布进行判断:
若为稳定收敛的数据和客户端分布时,客户端本地模型经过预设轮次的联邦训练后,将最后一个轮次的模型参数作为训练结果;
若为不能稳定收敛的数据和客户端分布时,通过将数据集中的验证集添加到联邦训练中,选取预设联邦训练轮次内的验证集中表现最优的客户端本地模型的模型参数作为训练结果。
进一步地,可以稳定收敛指模型的训练损失在一定轮次后变化较小,如手写数字分类为可以稳定收敛的数据和客户端分布;不能稳定收敛指模型的训练损失在较多联邦轮次后变化仍然较大。其中,验证集从训练集中按一定比例划分,不参与训练,仅用于选择训练的模型参数。
在训练过程中:
通过计算训练集的多任务损失以及反向传播即可更新模型参数,若需要对模型进行预测推理步骤,比如计算训练集、验证集和测试集的准确度,需要对输入数据的类别进行推理时,在推理阶段使用集成推理方法来得到客户端本地模型对输入数据的输出结果;
其中,集成推理方法具体如下:
将本地图像数据输入到全局任务子模型,输出概率形式的全局任务输出yg
将本地图像数据输入到个性化任务子模型,输出概率形式的个性化任务输出yp
比较上述两个子模型输出的所有类别对应的概率值,选择其中最大的概率值对应的类别作为客户端本地模型的分类结果,计算模型的准确率。若在无标签的预测任务上,同样进行以上三个步骤以输出预测结果。
应用实例:
本发明以5个客户端分别持有不同特征分布的手写数字数据为例,首先将每个客户端按比例7:3划分训练数据和测试数据,根据客户端是否参与联邦训练形成不同的客户端数据状态:联邦可见数据,联邦不可见数据。其中,采用留一法循环从5个客户端中选择4个客户端参与联邦训练,分别提供用于训练模型的训练数据和用于评估模型的测试数据,测试数据为对应客户端的可见数据;留下1个客户端不参与训练仅提供测试数据,该数据为联邦不可见数据。参与联邦训练的4个客户端分别生成客户端本地模型,在不可见数据上的表现即为模型的泛化性,在其对应客户端可见数据的表现则为模型的个性化性能。每个客户端分别有743张不重叠的手写数字图像。本地客户端本地模型选择如图1所示的卷积神经网络进行实际应用。
当所述方法应用于求解不同医院之间的数据孤岛问题时,每个客户端可以视为一个独立的医院,同时不同医院具有不同的数据分布。综上所述,本发明中一个联邦中客户端的知识可以为其他客户端所理解,而无需显式地共享其私有数据,通过联邦聚合和个性化优化,进一步发掘各个参与方的数据价值,提高模型训练的收敛性、鲁棒性和泛化性。
实施例2
本实施例提供了一种高泛化性的个性化联邦学习实现方法,应用于服务端,包括:
Step1:服务端随机初始化双分支结构的全局模型,与参与训练的客户端生成连接,并将初始化模型参数发送给参与训练的客户端,等待客户端利用本地图像数据进行训练;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
Step2:接收所有参与训练的客户端上传的全局模型训练的统计参数、全局任务子模型的模型参数和客户端客户端本地模型的评价结果;
Step3:若联邦训练轮次或聚合的评价结果满足预设标准,则停止联邦训练;若联邦训练轮次和聚合的评价结果不满足预设标准,则将所上传的全局任务子模型参数进行加权平均计算得到聚合后的全局任务子模型的模型参数,并将聚合后的模型参数发送给参与训练的客户端,等待客户端训练,返回执行Step2,进行循环更新;所述聚合的评价结果是指服务端对各参与训练的客户端上传的客户端本地模型的评价结果进行聚合后的最终结果。本实施例中客户端客户端本地模型的评价结果的指标包括训练集、验证集、测试集的准确率。
实施例3
本实施例提供了一种高泛化性的个性化联邦学习实现方法,应用于客户端,包括:
S1:与服务端生成连接,接收服务端发送的初始化模型参数对客户端本地模型进行初始化;其中,客户端本地模型包括全局任务子模型分支和个性化任务子模型分支;
S2:利用本地图像数据对客户端本地模型进行进行一轮迭代训练,得到客户端本地模型的模型参数,将客户端本地模型中的统计参数、客户端本地模型是否满足预设标准的评价结果和全局子模型的模型参数上传至服务端;
S3:等待服务端发送结束训练指令,若指令为结束训练,进而结束训练并保存预设最佳的客户端本地模型;若指令为继续训练,则等待服务端对全局子模型的模型参数进行聚合,接收服务器发送的聚合后的全局子模型参数,更新客户端本地模型中的全局子模型参数,返回S2,进行循环更新。
可以理解的是,上述各实施例中相同或相似部分可以相互参考,在一些实施例中未详细说明的内容可以参见其他实施例中相同或相似的内容。
尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。

Claims (10)

1.一种高泛化性的个性化联邦学习实现方法,其特征在于,包括,
步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端;
步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端;
步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;
步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。
2.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,服务端使用的全局模型和客服端使用的客户端本地模型结构相同,包括模型的特征提取层和相应的批归一化层;其中,特征提取层为任务共享层,批归一化层为任务特定层;任务特定层包括全局批归一化层和个性化批归一化层。
3.根据权利要求2所述的高泛化性的个性化联邦学习实现方法,其特征在于,全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成。
4.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,所述统计参数包括客户端参与训练的数据量。
5.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,S2中本地迭代训练得到更新后的客户端本地模型的过程具体为:
将本地数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出yg和个性化任务输出yl,通过计算交叉熵损失分别得到全局任务损失lossg和个性化任务损失lossl
交叉熵损失的表达式如下:
Figure FDA0003873865280000011
其中,a取g或l;yj为预测目标,
Figure FDA0003873865280000021
是实际预测结果;m表示参与训练的客户端的数量;
利用全局任务损失lossg和个性化任务损失lossl得到总体损失lossoverall,表示式为:
lossoverall=αlossg+(1-α)lossl
其中,α为损失比例系数;
结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数,客户端本地模型的模型参数更新表达式如下:
Figure FDA0003873865280000022
其中,gl表示个性化任务子模型优化得到的一次迭代的总体梯度;gg分别表示全局任务子模型优化得到的一次迭代的总体梯度;
Figure FDA0003873865280000023
wg表示全局任务子模型的模型参数;wl表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;i表示第i个客户端。
6.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,S3中通过加权平均计算得到聚合后新的全局任务子模型的模型参数具体为:
计算客户端参与训练的数据量占所有客户端参与训练数据总量的比重;
全局任务子模型的模型参数wg的更新公式如下:
Figure FDA0003873865280000024
其中,K表示参与训练的客户端总数;k表示第k个客户端;n表示所有客户端参与训练的数据总量;nk表示第k个客户端训练的数据量;
Figure FDA0003873865280000025
表示第k个客户端在第t轮联邦训练中的全局任务子模型的模型参数;wg,t+1表示第k个客户端在第t+1轮联邦训练中的全局任务子模型的模型参数。
7.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,S5中预设标准具体为:
根据损失曲线对数据和客户端分布进行判断:
若为稳定收敛的数据和客户端分布时,客户端本地模型经过预设轮次的联邦训练后,将最后一个轮次的模型参数作为训练结果;
若为不能稳定收敛的数据和客户端分布时,通过将数据集划分出的验证集添加到联邦训练中,选取预设联邦训练轮次内在验证集中表现最优的客户端本地模型的模型参数作为训练结果。
8.根据权利要求7所述的高泛化性的个性化联邦学习实现方法,其特征在于,在训练过程中:
对输入数据的输出结果进行推理时,在推理阶段使用集成推理方法来得到客户端本地模型对输入数据的输出结果;
其中,集成推理方法具体如下:
本地数据输入到全局任务子模型,输出概率形式的全局任务输出yg
本地数据输入到个性化任务子模型,输出概率形式的个性化任务输出yp
比较上述两个子模型输出的所有类别对应的概率值,选择其中最大的概率值对应的类别作为客户端本地模型的输出结果。
9.一种高泛化性的个性化联邦学习实现方法,应用于服务端,其特征在于,包括:
Step1:服务端随机初始化双分支结构的全局模型,与参与训练的客户端生成连接,并将初始化模型参数发送给参与训练的客户端,等待客户端训练;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
Step2:接收所有参与训练的客户端上传的客户端本地模型训练的统计参数、全局任务子模型的模型参数和客户端本地模型的评价结果;
Step3:若联邦训练轮次或聚合的评价结果满足预设标准,则停止联邦训练;若联邦训练轮次和聚合的评价结果不满足预设标准,则将所上传的全局任务子模型参数进行加权平均计算得到聚合后的全局任务子模型的模型参数,并将聚合后的模型参数发送给参与训练的客户端,等待客户端训练,返回执行Step2,进行循环更新;所述聚合的评价结果是指服务端对各参与训练的客户端上传的客户端本地模型的评价结果进行聚合后的最终结果。
10.一种高泛化性的个性化联邦学习实现方法,应用于客户端,其特征在于,包括:
S1:与服务端生成连接,接收服务端发送的初始化模型参数对客户端本地模型进行初始化;其中,客户端本地模型包括全局任务子模型分支和个性化任务子模型分支;
S2:利用本地数据对客户端本地模型进行进行一轮迭代训练,得到客户端本地模型的模型参数,将客户端本地模型中的统计参数、客户端本地模型是否满足预设标准的评价结果和全局子模型的模型参数上传至服务端;
S3:等待服务端发送指令,若指令为结束训练,进而结束训练并保存预设最佳的客户端本地模型;若指令为继续训练,则等待服务端对全局任务子模型的模型参数进行聚合,接收服务器发送的聚合后的全局任务子模型参数,更新客户端本地模型中的全局任务子模型参数,返回S2,进行循环更新。
CN202211206093.4A 2022-09-30 2022-09-30 一种高泛化性的个性化联邦学习实现方法 Pending CN115511109A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211206093.4A CN115511109A (zh) 2022-09-30 2022-09-30 一种高泛化性的个性化联邦学习实现方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211206093.4A CN115511109A (zh) 2022-09-30 2022-09-30 一种高泛化性的个性化联邦学习实现方法

Publications (1)

Publication Number Publication Date
CN115511109A true CN115511109A (zh) 2022-12-23

Family

ID=84509247

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211206093.4A Pending CN115511109A (zh) 2022-09-30 2022-09-30 一种高泛化性的个性化联邦学习实现方法

Country Status (1)

Country Link
CN (1) CN115511109A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115796275A (zh) * 2023-01-05 2023-03-14 成都墨甲信息科技有限公司 基于区块链的联邦学习方法、装置、电子设备及存储介质
CN116522988A (zh) * 2023-07-03 2023-08-01 粤港澳大湾区数字经济研究院(福田) 基于图结构学习的联邦学习方法、系统、终端及介质
CN116541712A (zh) * 2023-06-26 2023-08-04 杭州金智塔科技有限公司 基于非独立同分布数据的联邦建模方法及系统
CN117708681A (zh) * 2024-02-06 2024-03-15 南京邮电大学 基于结构图指导的个性化联邦脑电信号分类方法及系统

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115796275A (zh) * 2023-01-05 2023-03-14 成都墨甲信息科技有限公司 基于区块链的联邦学习方法、装置、电子设备及存储介质
CN116541712A (zh) * 2023-06-26 2023-08-04 杭州金智塔科技有限公司 基于非独立同分布数据的联邦建模方法及系统
CN116541712B (zh) * 2023-06-26 2023-12-26 杭州金智塔科技有限公司 基于非独立同分布数据的联邦建模方法及系统
CN116522988A (zh) * 2023-07-03 2023-08-01 粤港澳大湾区数字经济研究院(福田) 基于图结构学习的联邦学习方法、系统、终端及介质
CN116522988B (zh) * 2023-07-03 2023-10-31 粤港澳大湾区数字经济研究院(福田) 基于图结构学习的联邦学习方法、系统、终端及介质
CN117708681A (zh) * 2024-02-06 2024-03-15 南京邮电大学 基于结构图指导的个性化联邦脑电信号分类方法及系统
CN117708681B (zh) * 2024-02-06 2024-04-26 南京邮电大学 基于结构图指导的个性化联邦脑电信号分类方法及系统

Similar Documents

Publication Publication Date Title
CN115511109A (zh) 一种高泛化性的个性化联邦学习实现方法
CN109902706B (zh) 推荐方法及装置
Chandra Competition and collaboration in cooperative coevolution of Elman recurrent neural networks for time-series prediction
US20190073580A1 (en) Sparse Neural Network Modeling Infrastructure
US20190073581A1 (en) Mixed Machine Learning Architecture
CN114415735B (zh) 面向动态环境的多无人机分布式智能任务分配方法
CN110263236B (zh) 基于动态多视图学习模型的社交网络用户多标签分类方法
CN117236421B (zh) 一种基于联邦知识蒸馏的大模型训练方法
CN113190688A (zh) 基于逻辑推理和图卷积的复杂网络链接预测方法及系统
CN115344883A (zh) 一种用于处理不平衡数据的个性化联邦学习方法和装置
CN115270001B (zh) 基于云端协同学习的隐私保护推荐方法及系统
CN109670927A (zh) 信用额度的调整方法及其装置、设备、存储介质
CN112292696A (zh) 确定执行设备的动作选择方针
CN114819091B (zh) 基于自适应任务权重的多任务网络模型训练方法及系统
CN114091659A (zh) 基于空时信息的超低延时脉冲神经网络及学习方法
Wang et al. Digital-twin-aided product design framework for IoT platforms
CN117574429A (zh) 一种边缘计算网络中隐私强化的联邦深度学习方法
CN117523291A (zh) 基于联邦知识蒸馏和集成学习的图像分类方法
Zou et al. FedDCS: Federated learning framework based on dynamic client selection
CN117273105A (zh) 一种针对神经网络模型的模块构建方法及装置
Tian et al. Synergetic focal loss for imbalanced classification in federated xgboost
CN116645130A (zh) 基于联邦学习与gru结合的汽车订单需求量预测方法
WO2022127603A1 (zh) 一种模型处理方法及相关装置
CN113535911B (zh) 奖励模型处理方法、电子设备、介质和计算机程序产品
CN115908600A (zh) 基于先验正则化的大批量图像重建方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
TA01 Transfer of patent application right
TA01 Transfer of patent application right

Effective date of registration: 20230504

Address after: Building G4, China Minmetals Lushan Science and Technology Innovation Park, No. 966 Lushan South Road, Yuelu Street, Yuelu District, Changsha City, Hunan Province, 410000

Applicant after: Hunan Huaxin Software Co.,Ltd.

Address before: Yuelu District City, Hunan province 410083 Changsha Lushan Road No. 932

Applicant before: CENTRAL SOUTH University