CN115222061A - 基于持续学习的联邦学习方法以及相关设备 - Google Patents

基于持续学习的联邦学习方法以及相关设备 Download PDF

Info

Publication number
CN115222061A
CN115222061A CN202210908742.9A CN202210908742A CN115222061A CN 115222061 A CN115222061 A CN 115222061A CN 202210908742 A CN202210908742 A CN 202210908742A CN 115222061 A CN115222061 A CN 115222061A
Authority
CN
China
Prior art keywords
sample data
model
learning
data set
local
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210908742.9A
Other languages
English (en)
Inventor
李泽远
王健宗
曹康养
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202210908742.9A priority Critical patent/CN115222061A/zh
Publication of CN115222061A publication Critical patent/CN115222061A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请涉及计算机技术领域,提供了一种基于持续学习的联邦学习方法以及相关设备,方法应用于包括服务器、多个客户端的联邦学习系统,服务器和客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型,所述方法通过客户端获取第二样本数据集,第二样本数据集和第一样本数据集分别用于第一联邦学习模型不同的学习任务,并从第一样本数据中提取多个样本数据作为辅助样本数据,在模型学习第二样本数据集的同时,最小化学习多个辅助样本数据时产生的损失对本地模型进行训练。本申请实施例通过在模型学习新任务的同时,融入学习多个旧样本时产生的损失来修正模型梯度,保护模型学习到的旧知识,缓解灾难性遗忘。

Description

基于持续学习的联邦学习方法以及相关设备
技术领域
本申请涉及计算机技术领域,尤其涉及一种基于持续学习的联邦学习方法以及相关设备。
背景技术
联邦学习是一种打破数据孤岛、保护数据隐私的分布式机器学习技术,可以在不交换本地数据的情况下,多中心联合训练一个机器学习模型,相较于单中心数据训练的模型,联邦学习模型往往具有更高的分割性能和泛化性能。
然而,联邦学习在面对按时序到来的一系列任务的持续学习过程中,会出现全局模型在旧任务上的表现会随着新任务的学习而显著下降的情况,也就是出现灾难性遗忘现象。
发明内容
本申请实施例的主要目的在于提出一种基于持续学习的联邦学习方法、系统、电子设备及计算机可读存储介质,能够缓解联邦学习中模型的灾难性遗忘。
为实现上述目的,本申请实施例的第一方面提出了一种基于持续学习的联邦学习方法,所述方法应用于联邦学习系统,所述系统包括服务器、多个客户端,所述服务器分别与多个所述客户端通信连接,所述服务器和所述客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型,所述方法包括:
通过所述客户端获取第二样本数据集,所述第二样本数据集和所述第一样本数据集分别用于所述第一联邦学习模型不同的学习任务;
通过所述客户端从所述第一样本数据集中提取多个样本数据作为辅助样本数据;
通过所述客户端将所述第一联邦学习模型作为本地模型,以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练;
通过所述客户端将训练后的所述本地模型上传至所述服务器;
通过所述服务器接收多个所述客户端上传的所述本地模型,对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型。
根据本发明一些实施例提供的基于持续学习的联邦学习方法,在所述通过所述客户端从所述第一样本数据集中提取多个样本数据作为辅助样本数据之后,所述方法还包括:
通过所述客户端将多个所述辅助样本数据输入至所述第一联邦学习模型,以通过所述第一联邦学习模型得到多个所述辅助样本数据对应的第一分类预测值;
所述以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练,包括:
将所述第二样本数据集以及多个所述辅助样本数据输入至所述本地模型,以通过所述本地模型得到所述第二样本数据集中每个样本数据以及多个所述辅助样本数据对应的第二分类预测值;
确定所述第二样本数据集中每个样本数据的真实分类标签,以最小化所述第二样本数据集中每个样本数据的第二分类预测值以及真实分类标签之间的差异为训练目标,确定第一损失函数;
以最小化每个所述辅助样本数据的第一分类预测值和第二分类预测值之间的差异为训练目标,确定第二损失函数;
基于所述第一损失函数和所述第二损失函数对所述本地模型进行训练。
根据本发明一些实施例提供的基于持续学习的联邦学习方法,
所述第一损失函数通过以下公式确定:
Figure BDA0003773370390000021
其中,所述LC为第一损失函数,所述M表示所述第二样本数据集中分类的类别数量,所述yC为独热编码向量,若所述样本数据的真实分类标签与分类C相同,则yC取1,否则取0,所述pC表征所述样本数据属于分类C的第二分类预测值。
根据本发明一些实施例提供的基于持续学习的联邦学习方法,所述第二损失函数通过以下公式确定:
Figure BDA0003773370390000022
其中,所述LMSE为第二损失函数,所述n为所述辅助样本数据的样本数量,所述y′为所述第二分类预测值,所述y为第一分类预测值。
根据本发明一些实施例提供的基于持续学习的联邦学习方法,在所述对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型之前,所述方法还包括:
获取每个所述客户端中所述第二样本数据集的样本数量;
所述对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型,包括:
根据每个所述客户端中所述第二样本数据集的样本数量,确定每个所述本地模型的第一权重系数;
根据每个所述本地模型的第一权重系数,对多个所述本地模型的模型参数进行加权平均处理,得到中间模型;
对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型。
根据本发明一些实施例提供的基于持续学习的联邦学习方法,所述对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型,包括:
获取动态的第二权重系数;
根据所述第二权重系数,对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型;
对所述第二联邦学习模型进行测试,若所述第二联邦学习模型的测试结果未满足预设的整合结束条件,则返回获取动态的第二权重系数这一步骤,直至所述第二联邦学习模型的测试结果满足所述整合结束条件。
根据本发明一些实施例提供的基于持续学习的联邦学习方法,在所述对所述第二联邦学习模型进行测试之前,所述方法还包括:
获取测试样本数据集;
所述对所述第二联邦学习模型进行测试,包括:
将所述测试样本数据集输入到所述第二联邦学习模型,以通过所述第二联邦学习模型得到所述测试样本数据集中每个测试样本数据对应的第三分类预测值;
确定所述测试样本数据集中每个测试样本数据的真实分类标签,根据所述测试样本数据对应的第三分类预测值以及真实分类标签,确定并记录所述第二联邦学习模型的评价指标;
其中,所述评价指标包括以下至少之一:戴斯相似性系数、交并比系数或准确率。
为实现上述目的,本申请实施例的第二方面提出了一种基于持续学习的联邦学习系统,所述系统包括服务器、多个客户端,所述服务器分别与多个所述客户端通信连接,所述服务器和所述客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型;其中,
所述客户端包括:
获取模块,用于获取第二样本数据集,所述第一样本数据集和所述第二样本数据集分别用于所述第一联邦学习模型不同的学习任务;
提取模块,用于从所述第一样本数据集中提取多个样本数据作为辅助样本数据;
训练模块,用于将所述第一联邦学习模型作为本地模型,以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练;
发送模块,用于将训练后的所述本地模型上传至所述服务器;
所述服务器包括:
接收模块,用于接收多个所述客户端上传的所述本地模型;
整合模块,用于对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型。
为实现上述目的,本申请实施例的第三方面提出了一种电子设备,所述电子设备包括存储器、处理器、存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现上述第一方面所述的方法。
为实现上述目的,本申请实施例的第四方面提出了一种存储介质,所述存储介质为计算机可读存储介质,用于计算机可读存储,所述存储介质存储有一个或者多个计算机程序,所述一个或者多个计算机程序可被一个或者多个处理器执行,以实现上述第一方面所述的方法。
本申请提出一种基于持续学习的联邦学习方法、系统、电子设备以及计算机可读存储介质,所述方法应用于联邦学习系统,所述系统包括服务器、多个客户端,服务器分别与多个客户端通信连接,所述服务器和所述客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型,所述方法通过客户端获取与第一样本数据集用于第一联邦学习模型的不同学习任务的第二样本数据集,并从第一样本数据集中提取多个样本数据作为辅助样本数据,并通过客户端将第一联邦学习模型作为本地模型,并以最小化本地模型在学习第二样本数据集以及多个辅助样本数据时产生的损失为训练目标,基于第二样本数据集以及多个辅助样本数据对本地模型进行训练,之后通过服务器接收多个客户端上传的训练好的本地模型,并对第一联邦学习模型和多个本地模型进行整合处理,得到第二联邦学习模型。本申请实施例在本地模型学习新任务的同时,融入学习多个旧样本时产生的损失来修正模型梯度,以保护模型学习到的旧知识,缓解联邦学习中模型的灾难性遗忘。
附图说明
图1是本申请实施例提供的一种基于持续学习的联邦学习方法的流程示意图;
图2是本申请另一实施例提供的一种基于持续学习的联邦学习方法的流程示意图;
图3是本申请另一实施例提供的一种基于持续学习的联邦学习方法的流程示意图;
图4是本申请另一实施例提供的一种基于持续学习的联邦学习方法的流程示意图;
图5是本申请另一实施例提供的一种基于持续学习的联邦学习方法的流程示意图;
图6是本申请实施例提供的一种基于持续学习的联邦学习方法的实施环境图;
图7是本申请实施例提供的一种基于持续学习的联邦学习系统的结构示意图;
图8是本申请实施例提供的一种电子设备的硬件结构示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本申请,并不用于限定本申请。
需要说明的是,除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
联邦学习是一种打破数据孤岛、保护数据隐私的分布式机器学习技术,可以在不交换本地数据的情况下,多中心联合训练一个机器学习模型,相较于单中心数据训练的模型,联邦学习模型往往具有更高的分割性能和泛化性能。
然而,联邦学习在面对按时序到来的一系列任务的过程中,会出现全局模型在旧任务上的表现会随着新任务的学习而显著下降的情况,也就是出现灾难性遗忘现象。
基于此,本申请实施例提供了一种基于持续学习的联邦学习方法、系统、电子设备及计算机可读存储介质,能够缓解联邦学习中模型的灾难性遗忘。
本申请实施例提供的一种基于持续学习的联邦学习方法、系统、电子设备及计算机可读存储介质,具体通过如下实施例进行说明。首先描述本申请实施例中的基于持续学习的联邦学习方法。
本申请实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
本申请可用于众多通用或专用的计算机系统环境或配置中。例如:个人计算机、服务器计算机、手持设备或便携式设备、平板型设备、多处理器系统、基于微处理器的系统、置顶盒、可编程的消费电子设备、网络PC、小型计算机、大型计算机、包括以上任何系统或设备的分布式计算环境等等。本申请可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本申请,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
请参见图6,图6是本申请实施例提供的一种基于持续学习的实施环境图,所述方法应用于联邦学习系统,系统包括服务器、多个客户端,服务器分别与多个客户端通信连接,服务器和客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型,请参见图1,图1示出了本申请实施例提供的一种基于持续学习的联邦学习方法的流程示意图,如图1所示,所述基于持续学习的联邦学习方法包括但不限于步骤S110至S150。
步骤S110,通过所述客户端获取第二样本数据集,所述第二样本数据集和所述第一样本数据集分别用于所述第一联邦学习模型不同的学习任务。
步骤S120,通过所述客户端从所述第一样本数据集中提取多个样本数据作为辅助样本数据。
步骤S130,通过所述客户端将所述第一联邦学习模型作为本地模型,以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练。
步骤S140,通过所述客户端将训练后的所述本地模型上传至所述服务器。
步骤S150,通过所述服务器接收多个所述客户端上传的所述本地模型,对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型。
示例性的,第一联邦学习模型为医学图像分割模型,第一样本数据集为器官图像样本。因此,用于医学图像分割的第一联邦学习模型通过学习第一样本数据集后,即能够对其他器官图像样本进行较为准确的图像分割,将第一联邦模型对第一样本数据集的学习称为旧任务,将模型在旧任务中学习到的知识称为旧知识;通过客户端获取第二样本数据集,例如细胞图像样本,之后将第一联邦学习模型作为本地模型,利用第二样本数据集对本地模型进行训练,也就是让本地模型学习新任务,使得本地模型能够对其他细胞图像样本进行较为准确的图像分割。
可以理解的是,通过客户端获取与第一样本数据集用于不同学习任务的第二样本数据集,并从第一样本数据集中提取多个样本数据作为辅助样本数据,再将第一联邦学习模型作为本地模型,并基于第二样本数据集和多个辅助样本数据对本地模型进行训练,在本地模型对第二样本数据集进行学习的同时,融入学习多个辅助样本数据时产生的损失来修正本地模型的模型梯度,而服务器整合进行旧任务学习得到的第一联邦学习模型以及进行新任务学习得到的本地模型,得到在新旧任务上表现良好的第二联邦学习模型,实现联邦学习中全局模型在学习新任务的同时,保护模型的旧知识,缓解联邦学习中模型的灾难性遗忘,在客户端以及服务器的计算和存储资源有限的情况下进行持续学习。
需要说明的是,在步骤S120中,可以通过客户端预设目标数值,并从第一样本数据集中随机提取目标数值的样本数据作为辅助样本数据。
在一些实施例中,所述方法还包括:
获取所述第一样本数据集中每个样本数据的预测难度,并按照所述预测难度从大到小对所述第一样本数据集中的每个样本数据进行排序;
所述步骤S120包括:
按照排序从所述第一样本数据集中提取多个样本数据作为辅助样本数据。
需要说明的是,按照排序从第一样本数据集中提取多个样本数据作为辅助样本数据,也就是从第一样本数据集中提取预测难度大的多个样本数据作为辅助样本数据,并基于第二样本数据集以及多个辅助样本数据对本地模型进行训练。
可以理解的是,通过从第一样本数据集中选取预测难度大,有代表性的样本数据作为辅助样本数据,能够在保证本地模型对第一样本数据的模型性能,也就是保护模型的旧知识的同时,减少辅助样本数据的样本数量,提高模型的训练效率。
在一些实施例中,所述获取第一样本数据集中每个样本数据的预测难度包括:
将所述第一样本数据集输入至所述第一联邦学习模型,以通过所述第一联邦学习模型得到所述第一样本数据集中每个样本数据对应的分类预测值;
根据所述分类预测值,确定所述第一样本数据集中每个样本数据对应的信息熵;
所述按照所述预测难度从大到小对所述第一样本数据集中的每个样本数据进行排序,包括:
按照所述信息熵从大到小对所述第一样本数据集中的每个样本数据进行排序。
在一些实施例中,在所述步骤S120之后,在所述步骤S130之前,所述方法还包括:
通过所述客户端将多个所述辅助样本数据输入至所述第一联邦学习模型,以通过所述第一联邦学习模型得到多个所述辅助样本数据对应的第一分类预测值。
参见图2,图2示出了本申请实施例提供的一种基于持续学习的联邦学习方法的流程示意图,如图2所示,所述以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练,包括但不限于步骤S210至S240。
步骤S210,将所述第二样本数据集以及多个所述辅助样本数据输入至所述本地模型,以通过所述本地模型得到所述第二样本数据集中每个样本数据以及多个所述辅助样本数据对应的第二分类预测值。
步骤S220,确定所述第二样本数据集中每个样本数据的真实分类标签,以最小化所述第二样本数据集中每个样本数据的第二分类预测值以及真实分类标签之间的差异为训练目标,确定第一损失函数。
步骤S230,以最小化每个所述辅助样本数据的第一分类预测值和第二分类预测值之间的差异为训练目标,确定第二损失函数。
步骤S240,基于所述第一损失函数和所述第二损失函数对所述本地模型进行训练。
可以理解的是,以最小化本地模型学习第一样本数据集时产生的损失为训练目标,也就是通过学习第一样本数据集时产生的损失来修正本地模型的模型梯度,以提高本地模型对第一样本数据集的模型性能。具体的,可以通过本地模型得到第二样本数据集中每个样本数据对应的第二分类预测值,以及确定第二样本数据集中每个样本数据对应的真实分类标签,于是最小化本地模型学习第一样本数据集时产生的损失为训练目标,也就是最小化第二样本数据集中每个样本数据的第二分类预测值以及真实分别标签之间的差异为目标,确定第一损失函数,基于第一损失函数对本地模型进行训练。
可以理解的是,在最小化本地模型学习第二样本数据集时产生的损失的同时,融入以最小化本地模型在学习多个辅助样本数据时产生的损失为训练目标,也就是融入学习多个辅助样本数据时产生的损失共同修正本地模型的模型梯度,以减少本地模型和第一联邦学习模型在第一样本数据集上的模型性能之间的差异。具体的,可以通过第一联邦学习模型得到多个辅助样本数据对应的第一分类预测值,而在本地模型对第二样本数据集的学习过程中,通过本地模型得到多个辅助样本数据对应的第二分类预测值,于是以最小化本地模型在学习多个辅助样本数据时产生的损失为训练目标,也就是以最小化每个辅助样本数据的第一分类预测值和第二分类预测值之间的差异为目标,确定第二损失函数,基于第二损失函数对本地模型进行训练。
在一些实施例中,所述第一损失函数通过以下公式确定:
Figure BDA0003773370390000101
其中,所述LC为第一损失函数,所述M表示所述第二样本数据集中分类的类别数量,所述yC为独热编码向量,若所述样本数据的真实分类标签与分类C相同,则yC取1,否则取0,所述pC表征所述样本数据属于分类C的第二分类预测值。
在一些实施例中,所述第二损失函数通过以下公式确定:
Figure BDA0003773370390000102
其中,所述LMSE为第二损失函数,所述n为所述辅助样本数据的样本数量,所述y′为所述第二分类预测值,所述y为第一分类预测值。
在一个具体实施例中,根据第一损失函数和第二损失函数构建总损失函数,基于总损失函数对本地模型进行训练,所述总损失函数通过以下公式确定:
Figure BDA0003773370390000103
需要说明的是,还可以根据实际应用场景,获取预设的损失权重系数,以确定第一损失函数和第二损失函数在总损失函数中的占比,从而调整学习第一样本数据集和多个辅助样本数据时产生的损失对本地模型的影响,提高本申请实施例提供的联邦学习方法对不同应用场景的适应性。
还需要说明的是,上述实施例提供了确定第二样本数据集中每个样本数据的第二分类预测值以及真实分别标签之间差异以及每个辅助样本数据的第一分类预测值和第二分类预测值之间差异的第一损失函数和第二损失函数,其中第一损失函数具体为交叉熵函数,第二损失函数为均方差函数,应了解,本实施例还可以使用其他类型的损失函数进行差异确定,在此不一一举例。
在一些实施例中,在所述对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型之前,所述方法还包括:
获取每个所述客户端中所述第二样本数据集的样本数量。
请参见图3,图3示出了本申请实施例提供的一种基于持续学习的联邦学习方法的流程示意图,如图3所示,所述对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型,包括但不限于步骤S310至S330。
步骤S310,根据每个所述客户端中所述第二样本数据集的样本数量,确定每个所述本地模型的第一权重系数。
步骤S320,根据每个所述本地模型的第一权重系数,对多个所述本地模型的模型参数进行加权平均处理,得到中间模型。
步骤S330,对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型。
可以理解的是,在客户端训练好的本地模型在学习第二样本训练集后,也就是学习新任务后,保留了一定的学习第一样本数据集的旧知识,而服务器通过对第一联邦学习模型和多个本地模型进行整合处理,相比较只整合客户端上传的学习新任务后的本地模型,能进一步保护第二联邦学习模型上在第一样本数据集上学到的旧知识。
可以理解的是,根据每个客户端中的样本数量,确定每个本地模型对应的第一权重系数,从而合理地调整每个客户端上的本地模型对服务器端中的第二联邦学习模型的影响大小,能够在多个客户端之间存在样本数量分布不均,而导致训练得到的多个本地模型之间参数差异大的情况下,避免由服务器整合得到的第二联邦模型的模型性能下降的问题。
示例性的,在步骤S310至S330中,一共有m个客户端(集合为V)向服务器上传了模型参数为
Figure BDA0003773370390000111
的本地模型,服务器获取每个客户端中第二样本数据集的样本数量nk,求和得到m个客户端的样本数据总量,之后根据每个客户端中第二样本数据集的样本数量nk,确定每个本地模型的第一权重系数nk/n,再根据每个本地模型的第一权重系数,对多个本地模型的模型参数进行加权平均处理,得到中间模型,其具体公式为:
Figure BDA0003773370390000112
其中,所述wt+1为中间模型的模型参数。
在一些实施例中,请参见图4,图4为图3中步骤S330的子步骤流程图,如图4所示,所述步骤S330包括但不限于步骤S410至S430。
步骤S410,获取动态的第二权重系数。
步骤S420,根据所述第二权重系数,对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型。
步骤S430,对所述第二联邦学习模型进行测试,若所述第二联邦学习模型的测试结果未满足预设的整合结束条件,则返回获取动态的第二权重系数这一步骤,直至所述第二联邦学习模型的测试结果满足所述整合结束条件。
可以理解的是,通过获取动态的第二权重系数,之后调整学习新任务的中间模型在第二联邦学习模型中的占比大小,并对整合得到的第二联邦学习模型进行测试,若测试结果未满足预设的整合结束条件,则返回获取动态的第二权重系数这一步骤,直至测试结果满足整合结束条件,也就是不断调整第二权重系数,直至基于第二权重系数整合第一联邦学习模型和中间模型得到的第二联邦学习模型满足整合结束条件。
需要说明的是,整合结束条件可以是本轮整合的测试结果优于上一轮,也可以是测试结果达到预设阈值。
在一些实施例中,通过以下公式对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理:
Figure BDA0003773370390000121
其中,所述r2为第二联邦学习模型的模型参数,所述α为第二权重系数,且α∈[0,1],所述wt+1为中间模型的模型参数。
在一些实施例中,在所述对所述第二联邦学习模型进行测试之前,所述方法还包括:
获取测试样本数据集。
请参见图5,图5示出了本申请实施例提供的一种基于持续学习的联邦学习方法的流程示意图,如图5所示,所述对所述第二联邦学习模型进行测试,包括但不限于步骤S510至S520。
步骤S510,将所述测试样本数据集输入到所述第二联邦学习模型,以通过所述第二联邦学习模型得到所述测试样本数据集中每个测试样本数据对应的第三分类预测值。
步骤S520,确定所述测试样本数据集中每个测试样本数据的真实分类标签,根据所述测试样本数据对应的第三分类预测值以及真实分类标签,确定并记录所述第二联邦学习模型的评价指标。
其中,所述评价指标包括以下至少之一:戴斯相似性系数、交并比系数或准确率。
可以理解的是,获取测试样本数据集,可以由服务器搜集获取公开非隐私的样本数据集作为测试样本数据集。具体的,测试样本数据集中包含有与第一样本数据集和第二样本数据集对应的任务类型相匹配的样本数据,因此,通过该测试样本数据集即可在新任务和旧任务两个层面上,对第二联邦学习模型进行测试,以确定整合得到的第二联邦学习模型在新旧任务上的模型性能表现。
应了解,服务器获取测试样本数据集也可以通过接收客户端上传的测试样本数据集,同样,该测试样本数据集中亦包含有与第一样本数据集和第二样本数据集对应的任务类型相匹配的样本数据。
可以理解的是,在每轮整合过程中,确定并记录第二联邦学习模型的评价指标,若本轮整合得到的第二联邦学习模型对应的评价指标优于上一轮,则返回获取动态的第二权重系数这一步骤,继续调整第二权重系数对第一联邦学习模型和中间模型的模型参数进行加权求和处理,直至整合得到的第二联邦学习模型对应的评价指标低于上一轮,停止整合,将评价指标最优的模型作为最终的第二联邦学习模型。
本申请提出一种基于持续学习的联邦学习方法,所述方法应用于联邦学习系统,所述系统包括服务器、多个客户端,服务器分别与多个客户端通信连接,所述服务器和所述客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型,所述方法通过客户端获取与第一样本数据集用于第一联邦学习模型的不同学习任务的第二样本数据集,并从第一样本数据集中提取多个样本数据作为辅助样本数据,并通过客户端将第一联邦学习模型作为本地模型,并以最小化本地模型在学习第二样本数据集以及多个辅助样本数据时产生的损失为训练目标,基于第二样本数据集以及多个辅助样本数据对本地模型进行训练,之后通过服务器接收多个客户端上传的训练好的本地模型,并对第一联邦学习模型和多个本地模型进行整合处理,得到第二联邦学习模型。本申请实施例在本地模型学习新任务的同时,融入学习多个旧样本时产生的损失来修正模型梯度,以保护模型学习到的旧知识,缓解联邦学习中模型的灾难性遗忘。
请参见图7,本申请实施例还提供了一种基于持续学习的联邦学习系统100,所述基于持续学习的联邦学习系统100包括服务器120、多个客户端110,服务器120分别与多个客户端110通信连接,服务器120和客户端110均存储有基于第一样本数据集训练得到的第一联邦学习模型;其中,
所述客户端110包括:
获取模块111,用于获取第二样本数据集,所述第一样本数据集和所述第二样本数据集用于所述第一联邦学习模型的不同学习任务。
提取模块112,用于从所述第一样本数据集中提取多个样本数据作为辅助样本数据。
训练模块113,用于将所述第一联邦学习模型作为本地模型,以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练。
发送模块114,用于将训练后的所述本地模型上传至所述服务器。
所述服务器120包括:
接收模块121,用于接收多个所述客户端上传的所述本地模型。
整合模块122,用于对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型。
需要说明的是,上述装置的模块之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
请参见图8,图8示出本申请实施例提供的一种电子设备的硬件结构,电子设备包括:
处理器210,可以采用通用的CPU(Central Processing Unit,中央处理器)、微处理器、应用专用集合成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集合成电路等方式实现,用于执行相关计算机程序,以实现本申请实施例所提供的技术方案;
存储器220,可以采用只读存储器(Read Only Memory,ROM)、静态存储设备、动态存储设备或者随机存取存储器(Random Access Memory,RAM)等形式实现。存储器220可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器220中,并由处理器210来调用执行本申请实施例的基于持续学习的联邦学习方法;
输入/输出接口230,用于实现信息输入及输出;
通信接口240,用于实现本设备与其他设备的通信交互,可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信;和总线250,在设备的每个组件(例如处理器210、存储器220、输入/输出接口230和通信接口240)之间传输信息;
其中处理器210、存储器220、输入/输出接口230和通信接口240通过总线250实现彼此之间在设备内部的通信连接。
本申请实施例还提供了一种存储介质,存储介质为计算机可读存储介质,用于计算机可读存储,存储介质存储有一个或者多个计算机程序,一个或者多个计算机程序可被一个或者多个处理器执行,以实现上述基于持续学习的联邦学习方法。
存储器作为一种计算机可读存储介质,可用于存储软件程序以及计算机可执行程序。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施方式中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至该处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
本申请实施例描述的实施例是为了更加清楚的说明本申请实施例的技术方案,并不构成对于本申请实施例提供的技术方案的限定,本领域技术人员可知,随着技术的演变和新应用场景的出现,本申请实施例提供的技术方案对于类似的技术问题,同样适用。
以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。
本领域普通技术人员可以理解,上文中所公开方法中的全部或某些步骤、系统、设备中的功能模块/单元可以被实施为软件、固件、硬件及其适当的组合。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本申请的说明书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
应当理解,在本申请中,“至少一个(项)”是指一个或者多个,“多个”是指两个或两个以上。“和/或”,用于描述关联对象的关联关系,表示可以存在三种关系,例如,“A和/或B”可以表示:只存在A,只存在B以及同时存在A和B三种情况,其中A,B可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指这些项中的任意组合,包括单项(个)或复数项(个)的任意组合。例如,a,b或c中的至少一项(个),可以表示:a,b,c,“a和b”,“a和c”,“b和c”,或“a和b和c”,其中a,b,c可以是单个,也可以是多个。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,上述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集合成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
上述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请每个实施例中的各功能单元可以集合成在一个处理单元中,也可以是每个单元单独物理存在,也可以两个或两个以上单元集合成在一个单元中。上述集合成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集合成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括多指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请每个实施例的方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random Access Memory,简称RAM)、磁碟或者光盘等各种可以存储程序的介质。
以上参照附图说明了本申请实施例的优选实施例,并非因此局限本申请实施例的权利范围。本领域技术人员不脱离本申请实施例的范围和实质内所作的任何修改、等同替换和改进,均应在本申请实施例的权利范围之内。

Claims (10)

1.一种基于持续学习的联邦学习方法,其特征在于,所述方法应用于联邦学习系统,所述系统包括服务器、多个客户端,所述服务器分别与多个所述客户端通信连接,所述服务器和所述客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型,所述方法包括:
通过所述客户端获取第二样本数据集,所述第二样本数据集和所述第一样本数据集分别用于所述第一联邦学习模型不同的学习任务;
通过所述客户端从所述第一样本数据集中提取多个样本数据作为辅助样本数据;
通过所述客户端将所述第一联邦学习模型作为本地模型,以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练;
通过所述客户端将训练后的所述本地模型上传至所述服务器;
通过所述服务器接收多个所述客户端上传的所述本地模型,对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型。
2.根据权利要求1所述的基于持续学习的联邦学习方法,其特征在于,在所述通过所述客户端从所述第一样本数据集中提取多个样本数据作为辅助样本数据之后,所述方法还包括:
通过所述客户端将多个所述辅助样本数据输入至所述第一联邦学习模型,以通过所述第一联邦学习模型得到多个所述辅助样本数据对应的第一分类预测值;
所述以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练,包括:
将所述第二样本数据集以及多个所述辅助样本数据输入至所述本地模型,以通过所述本地模型得到所述第二样本数据集中每个样本数据以及多个所述辅助样本数据对应的第二分类预测值;
确定所述第二样本数据集中每个样本数据的真实分类标签,以最小化所述第二样本数据集中每个样本数据的第二分类预测值以及真实分类标签之间的差异为训练目标,确定第一损失函数;
以最小化每个所述辅助样本数据的第一分类预测值和第二分类预测值之间的差异为训练目标,确定第二损失函数;
基于所述第一损失函数和所述第二损失函数对所述本地模型进行训练。
3.根据权利要求2所述的基于持续学习的联邦学习方法,其特征在于,所述第一损失函数通过以下公式确定:
Figure FDA0003773370380000021
其中,所述LC为第一损失函数,所述M表示所述第二样本数据集中分类的类别数量,所述yC为独热编码向量,若所述样本数据的真实分类标签与分类C相同,则yC取1,否则取0,所述pC表征所述样本数据属于分类C的第二分类预测值。
4.根据权利要求2所述的基于持续学习的联邦学习方法,其特征在于,所述第二损失函数通过以下公式确定:
Figure FDA0003773370380000022
其中,所述LMSE为第二损失函数,所述n为所述辅助样本数据的样本数量,所述y′为所述第二分类预测值,所述y为第一分类预测值。
5.根据权利要求1所述的基于持续学习的联邦学习方法,其特征在于,在所述对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型之前,所述方法还包括:
获取每个所述客户端中所述第二样本数据集的样本数量;
所述对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型,包括:
根据每个所述客户端中所述第二样本数据集的样本数量,确定每个所述本地模型的第一权重系数;
根据每个所述本地模型的第一权重系数,对多个所述本地模型的模型参数进行加权平均处理,得到中间模型;
对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型。
6.根据权利要求5所述的基于持续学习的联邦学习方法,其特征在于,所述对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型,包括:
获取动态的第二权重系数;
根据所述第二权重系数,对所述第一联邦学习模型的模型参数和所述中间模型的模型参数进行加权平均处理,得到第二联邦学习模型;
对所述第二联邦学习模型进行测试,若所述第二联邦学习模型的测试结果未满足预设的整合结束条件,则返回获取动态的第二权重系数这一步骤,直至所述第二联邦学习模型的测试结果满足所述整合结束条件。
7.根据权利要求6所述的基于持续学习的联邦学习方法,其特征在于,在所述对所述第二联邦学习模型进行测试之前,所述方法还包括:
获取测试样本数据集;
所述对所述第二联邦学习模型进行测试,包括:
将所述测试样本数据集输入到所述第二联邦学习模型,以通过所述第二联邦学习模型得到所述测试样本数据集中每个测试样本数据对应的第三分类预测值;
确定所述测试样本数据集中每个测试样本数据的真实分类标签,根据所述测试样本数据对应的第三分类预测值以及真实分类标签,确定并记录所述第二联邦学习模型的评价指标;
其中,所述评价指标包括以下至少之一:戴斯相似性系数、交并比系数或准确率。
8.一种基于持续学习的联邦学习系统,其特征在于,所述系统包括服务器、多个客户端,所述服务器分别与多个所述客户端通信连接,所述服务器和所述客户端均存储有基于第一样本数据集训练得到的第一联邦学习模型;其中,
所述客户端包括:
获取模块,用于获取第二样本数据集,所述第一样本数据集和所述第二样本数据集分别用于所述第一联邦学习模型不同的学习任务;
提取模块,用于从所述第一样本数据集中提取多个样本数据作为辅助样本数据;
训练模块,用于将所述第一联邦学习模型作为本地模型,以最小化所述本地模型在学习所述第二样本数据集以及多个所述辅助样本数据时产生的损失为训练目标,基于所述第二样本数据集以及多个所述辅助样本数据对所述本地模型进行训练;
发送模块,用于将训练后的所述本地模型上传至所述服务器;
所述服务器包括:
接收模块,用于接收多个所述客户端上传的所述本地模型;
整合模块,用于对所述第一联邦学习模型和多个所述本地模型进行整合处理,得到第二联邦学习模型。
9.一种电子设备,其特征在于,包括:
至少一个处理器;以及,
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行如权利要求1至7中任一项所述的基于持续学习的联邦学习方法。
10.一种计算机可读存储介质,存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的基于持续学习的联邦学习方法。
CN202210908742.9A 2022-07-29 2022-07-29 基于持续学习的联邦学习方法以及相关设备 Pending CN115222061A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210908742.9A CN115222061A (zh) 2022-07-29 2022-07-29 基于持续学习的联邦学习方法以及相关设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210908742.9A CN115222061A (zh) 2022-07-29 2022-07-29 基于持续学习的联邦学习方法以及相关设备

Publications (1)

Publication Number Publication Date
CN115222061A true CN115222061A (zh) 2022-10-21

Family

ID=83613414

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210908742.9A Pending CN115222061A (zh) 2022-07-29 2022-07-29 基于持续学习的联邦学习方法以及相关设备

Country Status (1)

Country Link
CN (1) CN115222061A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116029371A (zh) * 2023-03-27 2023-04-28 北京邮电大学 基于预训练的联邦学习工作流构建方法及相关设备
CN116796860A (zh) * 2023-08-24 2023-09-22 腾讯科技(深圳)有限公司 联邦学习方法、装置、电子设备及存储介质

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116029371A (zh) * 2023-03-27 2023-04-28 北京邮电大学 基于预训练的联邦学习工作流构建方法及相关设备
CN116029371B (zh) * 2023-03-27 2023-06-06 北京邮电大学 基于预训练的联邦学习工作流构建方法及相关设备
CN116796860A (zh) * 2023-08-24 2023-09-22 腾讯科技(深圳)有限公司 联邦学习方法、装置、电子设备及存储介质
CN116796860B (zh) * 2023-08-24 2023-12-12 腾讯科技(深圳)有限公司 联邦学习方法、装置、电子设备及存储介质

Similar Documents

Publication Publication Date Title
CN108229478B (zh) 图像语义分割及训练方法和装置、电子设备、存储介质和程序
CN111860573B (zh) 模型训练方法、图像类别检测方法、装置和电子设备
CN112434721B (zh) 一种基于小样本学习的图像分类方法、系统、存储介质及终端
CA3066029A1 (en) Image feature acquisition
CN112183577A (zh) 一种半监督学习模型的训练方法、图像处理方法及设备
CN110070029B (zh) 一种步态识别方法及装置
CN115222061A (zh) 基于持续学习的联邦学习方法以及相关设备
CN108229522B (zh) 神经网络的训练方法、属性检测方法、装置及电子设备
CN110598603A (zh) 人脸识别模型获取方法、装置、设备和介质
CN113723288B (zh) 基于多模态混合模型的业务数据处理方法及装置
CN111931859B (zh) 一种多标签图像识别方法和装置
CN110597965B (zh) 文章的情感极性分析方法、装置、电子设备及存储介质
CN113657087B (zh) 信息的匹配方法及装置
CN113313215B (zh) 图像数据处理方法、装置、计算机设备和存储介质
CN111931809A (zh) 数据的处理方法、装置、存储介质及电子设备
CN117036843A (zh) 目标检测模型训练方法、目标检测方法和装置
CN111444850A (zh) 一种图片检测的方法和相关装置
CN114155397A (zh) 一种小样本图像分类方法及系统
CN114913330B (zh) 点云部件分割方法、装置、电子设备与存储介质
CN111611917A (zh) 模型训练方法、特征点检测方法、装置、设备及存储介质
CN115759293A (zh) 模型训练方法、图像检索方法、装置及电子设备
CN116958724A (zh) 一种产品分类模型的训练方法和相关装置
CN114973271A (zh) 一种文本信息提取方法、提取系统、电子设备及存储介质
CN115205546A (zh) 模型训练方法和装置、电子设备、存储介质
CN111582404B (zh) 内容分类方法、装置及可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination