CN115018019B - 基于联邦学习的模型训练方法及系统、存储介质 - Google Patents

基于联邦学习的模型训练方法及系统、存储介质 Download PDF

Info

Publication number
CN115018019B
CN115018019B CN202210939615.5A CN202210939615A CN115018019B CN 115018019 B CN115018019 B CN 115018019B CN 202210939615 A CN202210939615 A CN 202210939615A CN 115018019 B CN115018019 B CN 115018019B
Authority
CN
China
Prior art keywords
model
round
current round
local
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210939615.5A
Other languages
English (en)
Other versions
CN115018019A (zh
Inventor
谢翀
兰鹏
陈永红
赵豫陕
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Qianhai Huanrong Lianyi Information Technology Service Co Ltd
Original Assignee
Shenzhen Qianhai Huanrong Lianyi Information Technology Service Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Qianhai Huanrong Lianyi Information Technology Service Co Ltd filed Critical Shenzhen Qianhai Huanrong Lianyi Information Technology Service Co Ltd
Priority to CN202210939615.5A priority Critical patent/CN115018019B/zh
Publication of CN115018019A publication Critical patent/CN115018019A/zh
Application granted granted Critical
Publication of CN115018019B publication Critical patent/CN115018019B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning

Abstract

本申请公开了一种基于联邦学习的模型训练方法及系统、存储介质、计算机设备,该方法包括:服务器在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至本轮客户端中,本轮客户端按本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将本轮目标模型参数返回至服务器中;服务器对多个本轮客户端各自返回的本轮目标模型参数进行参数聚合,得到本轮聚合参数;当本轮聚合参数未达到服务器的采样条件时,将本轮聚合参数作为下轮初始模型参数;当本轮聚合参数达到服务器的采样条件时,将本轮聚合参数发送至客户端集合内的每个客户端中,每个客户端按本轮聚合参数配置本地模型后进行最后一轮模型训练。

Description

基于联邦学习的模型训练方法及系统、存储介质
技术领域
本申请涉及模型训练技术领域,尤其是涉及到一种基于联邦学习的模型训练方法及装置、存储介质、计算机设备。
背景技术
由于企业间的特殊关系,企业数据的隐私性要求极高,数据差异大且无法互通。科技公司针对某一企业的数据所开发的相同功能的算法模型往往无法快速应用在另外的企业,这不仅导致了科技公司的开发成本高、开发效率低,而且针对每个企业开发的模型泛化性差。
发明内容
有鉴于此,本申请提供了一种基于联邦学习的模型训练方法及装置、存储介质、计算机设备,训练过程中服务器不会接触任何客户端的训练样本,保障了各方的数据隐私需求,并且通过对多个客户端的模型进行协同、统一训练,提升了模型开发效率以及模型泛化性。
根据本申请的一个方面,提供了一种基于联邦学习的模型训练方法,所述方法包括:
服务器在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至所述本轮客户端中,所述本轮客户端按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将所述本轮目标模型参数返回至所述服务器中;
所述服务器对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数;
当所述本轮聚合参数未达到所述服务器的采样条件时,将所述本轮聚合参数作为下轮初始模型参数,并在所述客户端集合中重新采样多个下轮客户端,向所述下轮客户端发送所述下轮初始模型参数,以通过所述下轮客户端进行下个轮次的模型训练;
当所述本轮聚合参数达到所述服务器的采样条件时,将所述本轮聚合参数发送至所述客户端集合内的每个客户端中,每个所述客户端按所述本轮聚合参数配置本地模型后进行最后一轮模型训练。
可选地,所述服务器在客户端集合中采样多个本轮客户端之前,所述方法还包括:
初始化元模型参数,将所述元模型的初始化参数作为第一轮初始模型参数,其中,所述客户端集合中每个所述客户端对应的本地模型的模型结构均与所述元模型的模型结构相同。
可选地,所述将所述本轮目标模型参数返回至所述服务器中,包括:
所述本轮客户端将所述本轮目标模型参数以及本轮训练样本量返回至所述服务器中;
所述服务器对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数,包括:
所述服务器根据每个所述本轮客户端对应的所述本轮训练样本量占本轮训练样本总数的比例,确定每个所述本轮客户端对应的所述本轮目标模型参数的参数权重,并按所述参数权重对所述本轮目标模型参数进行加权求和,得到所述本轮聚合参数。
可选地,所述元模型为分类模型,用于预测目标输入量属于不同类别的概率;所述本轮客户端按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,包括:
若所述本轮客户端为首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数作为所述本地模型的目标损失函数,对所述本地模型进行训练。
可选地,所述方法还包括:
若所述本轮客户端为非首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
以所述本轮客户端中当前的本地模型作为参照模型,依据所述参照模型对训练样本的预测数据以及本轮训练中的本地模型对训练样本的预测数据之间的KL散度,确定所述本地模型的第二损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数与所述第二损失函数之和作为所述本地模型的目标损失函数,对所述本地模型进行训练。
可选地,所述依据所述参照模型对训练样本的预测数据以及本轮训练中的本地模型对训练样本的预测数据之间的KL散度,确定所述本地模型的第二损失函数,包括:
基于所述参照模型对训练样本的中间层输出数据以及本轮训练中的本地模型对训练样本的中间层输出数据之间的KL散度,确定第三损失函数;
基于所述参照模型对训练样本属于不同类别的预测概率以及本轮训练中的本地模型对训练样本属于不同类别的预测概率之间的KL散度,确定第四损失函数;
将所述第三损失函数和所述第四损失函数的加权求和结果作为所述第二损失函数。
可选地,所述采样条件为采样轮次阈值;所述元模型为图像分类模型。
根据本申请的另一方面,提供了一种基于联邦学习的模型训练系统,所述系统包括:
服务器和多个客户端,其中,多个客户端构成客户端集合;
所述服务器,用于在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至所述本轮客户端中;
所述本轮客户端,用于按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将所述本轮目标模型参数返回至所述服务器中;
所述服务器,还用于对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数;
所述服务器,还用于当所述本轮聚合参数未达到所述服务器的采样条件时,将所述本轮聚合参数作为下轮初始模型参数,并在所述客户端集合中重新采样多个下轮客户端,向所述下轮客户端发送所述下轮初始模型参数,以通过所述下轮客户端进行下个轮次的模型训练;
所述服务器,还用于当所述本轮聚合参数达到所述服务器的采样条件时,将所述本轮聚合参数发送至所述客户端集合内的每个客户端中,每个所述客户端按所述本轮聚合参数配置本地模型后进行最后一轮模型训练。
可选地,所述服务器,还用于在客户端集合中采样多个本轮客户端之前,初始化元模型参数,将所述元模型的初始化参数作为第一轮初始模型参数,其中,所述客户端集合中每个所述客户端对应的本地模型的模型结构均与所述元模型的模型结构相同。
可选地,所述本轮客户端,还用于所述本轮客户端将所述本轮目标模型参数以及本轮训练样本量返回至所述服务器中;
所述服务器,还用于根据每个所述本轮客户端对应的所述本轮训练样本量占本轮训练样本总数的比例,确定每个所述本轮客户端对应的所述本轮目标模型参数的参数权重,并按所述参数权重对所述本轮目标模型参数进行加权求和,得到所述本轮聚合参数。
可选地,所述元模型为分类模型,用于预测目标输入量属于不同类别的概率;所述本轮客户端,还用于:
若所述本轮客户端为首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数作为所述本地模型的目标损失函数,对所述本地模型进行训练。
可选地,所述本轮客户端,还用于:
若所述本轮客户端为非首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
以所述本轮客户端中当前的本地模型作为参照模型,依据所述参照模型对训练样本的预测数据以及本轮训练中的本地模型对训练样本的预测数据之间的KL散度,确定所述本地模型的第二损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数与所述第二损失函数之和作为所述本地模型的目标损失函数,对所述本地模型进行训练。
可选地,所述本轮客户端,还用于:
基于所述参照模型对训练样本的中间层输出数据以及本轮训练中的本地模型对训练样本的中间层输出数据之间的KL散度,确定第三损失函数;
基于所述参照模型对训练样本属于不同类别的预测概率以及本轮训练中的本地模型对训练样本属于不同类别的预测概率之间的KL散度,确定第四损失函数;
将所述第三损失函数和所述第四损失函数的加权求和结果作为所述第二损失函数。
可选地,所述采样条件为采样轮次阈值;所述元模型为图像分类模型。
依据本申请又一个方面,提供了一种存储介质,其上存储有计算机程序,所述程序被处理器执行时实现上述基于联邦学习的模型训练方法。
依据本申请再一个方面,提供了一种计算机设备,包括存储介质、处理器及存储在存储介质上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述基于联邦学习的模型训练方法。
借由上述技术方案,本申请提供的一种基于联邦学习的模型训练方法及装置、存储介质、计算机设备,通过服务器和多个客户端协同完成模型训练,客户端利用本地训练样本数据集进行模型训练,完成后将模型参数发送给服务器,服务器对收到的模型参数进行聚合后分发给参与的客户端。训练过程中服务器不会接触任何客户端的训练样本,保障了各方的数据隐私需求,并且通过对多个客户端的模型进行协同、统一训练,提升了模型开发效率以及模型泛化性。
上述说明仅是本申请技术方案的概述,为了能够更清楚了解本申请的技术手段,而可依照说明书的内容予以实施,并且为了让本申请的上述和其它目的、特征和优点能够更明显易懂,以下特举本申请的具体实施方式。
附图说明
此处所说明的附图用来提供对本申请的进一步理解,构成本申请的一部分,本申请的示意性实施例及其说明用于解释本申请,并不构成对本申请的不当限定。在附图中:
图1示出了本申请实施例提供的一种基于联邦学习的模型训练方法的流程示意图;
图2示出了本申请实施例提供的另一种基于联邦学习的模型训练方法的流程示意图;
图3示出了本申请实施例提供的一种基于联邦学习的模型训练系统的结构示意图。
具体实施方式
下文中将参考附图并结合实施例来详细说明本申请。需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。
在本实施例中提供了一种基于联邦学习的模型训练方法,如图1所示,该方法包括:
步骤101,服务器在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至所述本轮客户端中,所述本轮客户端按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将所述本轮目标模型参数返回至所述服务器中;
本申请实施例主要由两阶段的流程组成,分别为服务器流程和客户端流程,设置一个中心服务器和多个参与客户端,中心服务器主要负责收集所有参与客户端发送的模型参数信息,再对收到的模型参数信息进行聚合后分发给参与的客户端。这一过程服务器不会接触任何客户端的数据明文,保障了各方的数据隐私需求。客户端主要负责利用本地数据集进行模型训练,完成后将模型参数或部分统计信息发送给服务器。
本申请实施例提供的模型训练过程包括多轮训练,在其中任意一轮训练过程中,服务器在客户端集合中采样一组客户端,即本轮参与训练的本轮客户端,将预先确定的本轮初始模型参数发送至各本轮客户端中,对于任意一个本轮客户端来说,客户端接收到本轮初始模型参数后,对客户端本地模型进行参数赋值,利用本地数据对赋值后的本地模型进行训练,并将训练得到的本地模型参数作为该客户端对应的本轮目标模型参数返回到服务器中。
步骤102,所述服务器对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数;
在该实施例中,多个本轮客户端均完成本轮的模型训练后,服务器将接收到多组本轮目标模型参数,服务器对多组本轮目标模型参数进行参数聚合,将本轮多个客户端的训练结果聚合为一组参数,即本轮聚合参数。具体地,可以将各本轮目标模型参数的平均值作为本轮聚合参数。
在本申请实施例中,可选地,步骤101中“将所述本轮目标模型参数返回至所述服务器中”具体包括:所述本轮客户端将所述本轮目标模型参数以及本轮训练样本量返回至所述服务器中;
相应地,步骤102具体可以包括:所述服务器根据每个所述本轮客户端对应的所述本轮训练样本量占本轮训练样本总数的比例,确定每个所述本轮客户端对应的所述本轮目标模型参数的参数权重,并按所述参数权重对所述本轮目标模型参数进行加权求和,得到所述本轮聚合参数。
在该实施例中,客户端完成一轮模型训练后,除了将训练得到的目标模型参数发送到服务器外,还可以将本轮训练的训练样本量发送到服务器中,服务器进行参数聚合时,将各客户端的训练样本量占本轮训练样本总数的比例,确定为各客户端对应的目标模型参数的参数权重,并按参数权重对目标模型参数进行加权求和,得到本轮聚合参数。
步骤103,当所述本轮聚合参数未达到所述服务器的采样条件时,将所述本轮聚合参数作为下轮初始模型参数,并在所述客户端集合中重新采样多个下轮客户端,向所述下轮客户端发送所述下轮初始模型参数,以通过所述下轮客户端进行下个轮次的模型训练;
在该实施例中,得到本轮聚合参数后,判断是否已经达到了服务器的采样条件。其中采样条件可以为采样轮次阈值,服务器对客户端的采样轮次达到阈值认为满足采样条件,否则认为未达到采样条件。采样条件也可以为聚合参数偏差阈值,如果本轮聚合参数与上一轮聚合参数的偏差小于该阈值,认为满足采样条件,否则认为未达到采样条件。如果还没有达到服务器的采样条件,那么将本轮聚合参数作为下一轮模型训练的初始模型参数,重新在客户端集合中采样一组客户端,重复上述的训练过程完成下一轮的模型训练。
在本申请实施例中,可选地,步骤101之前还包括:初始化元模型参数,将所述元模型的初始化参数作为第一轮初始模型参数,其中,所述客户端集合中每个所述客户端对应的本地模型的模型结构均与所述元模型的模型结构相同。
在上述实施例中,第一轮初始模型参数通过服务器初始化元模型参数得到,客户端集合中各客户端的本地模型的模型结构均与元模型的模型结构相同。
步骤104,当所述本轮聚合参数达到所述服务器的采样条件时,将所述本轮聚合参数发送至所述客户端集合内的每个客户端中,每个所述客户端按所述本轮聚合参数配置本地模型后进行最后一轮模型训练。
在该实施例中,当达到服务器的采样条件时,将最后一轮的聚合参数发送到客户端集合内的每个客户端中,各客户端接收到该参数后,对本地模型进行参数配置,并完成最后一轮模型训练得到最终的本地模型,完成整个训练过程。
通过应用本实施例的技术方案,通过服务器和多个客户端协同完成模型训练,客户端利用本地训练样本数据集进行模型训练,完成后将模型参数发送给服务器,服务器对收到的模型参数进行聚合后分发给参与的客户端。训练过程中服务器不会接触任何客户端的训练样本,保障了各方的数据隐私需求,并且通过对多个客户端的模型进行协同、统一训练,提升了模型开发效率以及模型泛化性。
在本申请实施例中,对于客户端的模型训练流程,可选地,通过以下方式确定模型的损失函数并进行模型训练:
若所述本轮客户端为首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数作为所述本地模型的目标损失函数,对所述本地模型进行训练。
若所述本轮客户端为非首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;以所述本轮客户端中当前的本地模型作为参照模型,依据所述参照模型对训练样本的预测数据以及本轮训练中的本地模型对训练样本的预测数据之间的KL散度,确定所述本地模型的第二损失函数;所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数与所述第二损失函数之和作为所述本地模型的目标损失函数,对所述本地模型进行训练。
在上述实施例中,对于每个客户端来说,如果客户端是第一次被采样到,那么确定模型的目标损失函数为客户端本地训练样本的真实类别(即标签类别)与模型预测出的该类别的类别概率值之间的交叉熵,并使用本地训练样本进行模型训练。
而如果客户端不是第一次被采样到,除了上述第一次本地训练使用的目标损失函数即第一损失函数外,还可以参考之前的训练结果设置第二损失函数。具体地,在进行本轮模型训练之前,可以先复制当前的本地模型(即上一次训练得到的本地模型)作为参考模型,将参照模型对训练样本的预测数据与训练过程中参数变化后的本地模型对训练样本的预测数据之间的KL散度,作为第二损失函数,累加第一损失函数和第二损失函数得到目标损失函数,并使用本地训练样本进行模型训练。
在上述实施例中,可选地,第二损失函数具体通过以下方式确定:基于所述参照模型对训练样本的中间层输出数据以及本轮训练中的本地模型对训练样本的中间层输出数据之间的KL散度,确定第三损失函数;基于所述参照模型对训练样本属于不同类别的预测概率以及本轮训练中的本地模型对训练样本属于不同类别的预测概率之间的KL散度,确定第四损失函数;将所述第三损失函数和所述第四损失函数的加权求和结果作为所述第二损失函数。
在该实施例中,模型训练过程中会对训练样本输出隐层向量(即中间层输出数据)和归一化概率值(即不同类别的预测概率)。同时用参考模型对相同的训练样本进行一次预测操作,同样输出隐层向量和归一化概率值。将相同的训练样本分别输入到参考模型和训练中的本地模型中,获取参考模型的隐层向量和归一化概率值、以及本地模型的隐层向量和归一化概率值,计算参考模型和本地模型各自的隐层向量之间的KL散度,得到第三损失函数,计算参考模型和本地模型各自的归一化概率值之间的KL散度,得到第四损失函数,对第三损失函数和第四损失函数加权求和得到第二损失函数。
在一个具体的实施例中,如图2所示,服务端流程包括:
a、服务器初始化元模型参数;
b、服务器随机采样一组客户端,将初始的元模型参数发送给这组被采样客户端;
c、跳转至客户端流程,等待所有被采样客户端本地训练结束;
d、接收这组被采样客户端训练后的模型参数,计算每个被采样客户端的训练数据量占这轮训练中总训练数据量的百分比作为权重对模型参数进行加权求和;
e、求和的结果作为下一轮训练的初始模型参数,跳转至b步骤进行下一轮采样训练,达到最大通信采样轮次后进入f步骤;
f、服务器对所有参与客户端进行采样,将此时的元模型参数发送给每个参与客户端,各参与客户端进行参数配置后完成最后一轮模型训练,得到各自的本地模型,结束整个训练流程。
客户端流程:
a、接收服务器发送的初始模型参数,配置为本地模型的模型参数;
b、如果本客户端是第一次被采样到,跳转至c步骤,否则跳转至d步骤;
c、使用本地数据进行训练,损失函数是客户端本地训练样本的真实类别和模型预测的类别概率值之间的交叉熵;
d、使用本地数据进行训练,训练完成后跳转至步骤e,损失函数包括如下两项:
1)和c步骤相同的交叉熵;
2)模型训练过程中会对训练样本输出隐层向量和归一化概率值。同时将上一轮被采样训练之后的模型作为参考模型,用参考模型对本地训练样本进行一次预测操作,同样输出隐层向量和归一化概率值。分别计算本地训练得到的和参考模型预测的隐层向量及归一化概率值之间的KL散度作为额外的损失,这部分损失函数=k1*隐层向量KL散度+k2*归一化概率值KL散度。
e、在本地保存训练完成后的模型参数,覆盖前一次被采样时保存的参数,同时将该模型参数以及本轮的训练样本量发送给服务器端,客户端流程结束。
进一步的,作为图1方法的具体实现,本申请实施例提供了一种基于联邦学习的模型训练系统,如图3所示,该系统包括:
服务器和多个客户端,其中,多个客户端构成客户端集合;
所述服务器,用于在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至所述本轮客户端中;
所述本轮客户端,用于按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将所述本轮目标模型参数返回至所述服务器中;
所述服务器,还用于对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数;
所述服务器,还用于当所述本轮聚合参数未达到所述服务器的采样条件时,将所述本轮聚合参数作为下轮初始模型参数,并在所述客户端集合中重新采样多个下轮客户端,向所述下轮客户端发送所述下轮初始模型参数,以通过所述下轮客户端进行下个轮次的模型训练;
所述服务器,还用于当所述本轮聚合参数达到所述服务器的采样条件时,将所述本轮聚合参数发送至所述客户端集合内的每个客户端中,每个所述客户端按所述本轮聚合参数配置本地模型后进行最后一轮模型训练。
可选地,所述服务器,还用于在客户端集合中采样多个本轮客户端之前,初始化元模型参数,将所述元模型的初始化参数作为第一轮初始模型参数,其中,所述客户端集合中每个所述客户端对应的本地模型的模型结构均与所述元模型的模型结构相同。
可选地,所述本轮客户端,还用于所述本轮客户端将所述本轮目标模型参数以及本轮训练样本量返回至所述服务器中;
所述服务器,还用于根据每个所述本轮客户端对应的所述本轮训练样本量占本轮训练样本总数的比例,确定每个所述本轮客户端对应的所述本轮目标模型参数的参数权重,并按所述参数权重对所述本轮目标模型参数进行加权求和,得到所述本轮聚合参数。
可选地,所述元模型为分类模型,用于预测目标输入量属于不同类别的概率;所述本轮客户端,还用于:
若所述本轮客户端为首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数作为所述本地模型的目标损失函数,对所述本地模型进行训练。
可选地,所述本轮客户端,还用于:
若所述本轮客户端为非首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
以所述本轮客户端中当前的本地模型作为参照模型,依据所述参照模型对训练样本的预测数据以及本轮训练中的本地模型对训练样本的预测数据之间的KL散度,确定所述本地模型的第二损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数与所述第二损失函数之和作为所述本地模型的目标损失函数,对所述本地模型进行训练。
可选地,所述本轮客户端,还用于:
基于所述参照模型对训练样本的中间层输出数据以及本轮训练中的本地模型对训练样本的中间层输出数据之间的KL散度,确定第三损失函数;
基于所述参照模型对训练样本属于不同类别的预测概率以及本轮训练中的本地模型对训练样本属于不同类别的预测概率之间的KL散度,确定第四损失函数;
将所述第三损失函数和所述第四损失函数的加权求和结果作为所述第二损失函数。
可选地,所述采样条件为采样轮次阈值;所述元模型为图像分类模型。
需要说明的是,本申请实施例提供的一种基于联邦学习的模型训练系统所涉及各功能单元的其他相应描述,可以参考图1至图2方法中的对应描述,在此不再赘述。
基于上述如图1至图2所示方法,相应的,本申请实施例还提供了一种存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述如图1至图2所示的基于联邦学习的模型训练方法。
基于这样的理解,本申请的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施场景所述的方法。
基于上述如图1至图2所示的方法,以及图3所示的系统实施例,为了实现上述目的,本申请实施例还提供了一种计算机设备,具体可以为个人计算机、服务器、网络设备等,该计算机设备包括存储介质和处理器;存储介质,用于存储计算机程序;处理器,用于执行计算机程序以实现上述如图1至图2所示的基于联邦学习的模型训练方法。
可选地,该计算机设备还可以包括用户接口、网络接口、摄像头、射频(RadioFrequency,RF)电路,传感器、音频电路、WI-FI模块等等。用户接口可以包括显示屏(Display)、输入单元比如键盘(Keyboard)等,可选用户接口还可以包括USB接口、读卡器接口等。网络接口可选的可以包括标准的有线接口、无线接口(如蓝牙接口、WI-FI接口)等。
本领域技术人员可以理解,本实施例提供的一种计算机设备结构并不构成对该计算机设备的限定,可以包括更多或更少的部件,或者组合某些部件,或者不同的部件布置。
存储介质中还可以包括操作系统、网络通信模块。操作系统是管理和保存计算机设备硬件和软件资源的程序,支持信息处理程序以及其它软件和/或程序的运行。网络通信模块用于实现存储介质内部各组件之间的通信,以及与该实体设备中其它硬件和软件之间通信。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到本申请可以借助软件加必要的通用硬件平台的方式来实现,也可以通过硬件实现通过服务器和多个客户端协同完成模型训练,客户端利用本地训练样本数据集进行模型训练,完成后将模型参数发送给服务器,服务器对收到的模型参数进行聚合后分发给参与的客户端。训练过程中服务器不会接触任何客户端的训练样本,保障了各方的数据隐私需求,并且通过对多个客户端的模型进行协同、统一训练,提升了模型开发效率以及模型泛化性。
本领域技术人员可以理解附图只是一个优选实施场景的示意图,附图中的模块或流程并不一定是实施本申请所必须的。本领域技术人员可以理解实施场景中的装置中的模块可以按照实施场景描述进行分布于实施场景的装置中,也可以进行相应变化位于不同于本实施场景的一个或多个装置中。上述实施场景的模块可以合并为一个模块,也可以进一步拆分成多个子模块。
上述本申请序号仅仅为了描述,不代表实施场景的优劣。以上公开的仅为本申请的几个具体实施场景,但是,本申请并非局限于此,任何本领域的技术人员能思之的变化都应落入本申请的保护范围。

Claims (8)

1.一种基于联邦学习的模型训练方法,其特征在于,所述方法包括:
服务器在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至所述本轮客户端中,所述本轮客户端按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将所述本轮目标模型参数返回至所述服务器中;
所述服务器对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数;
当所述本轮聚合参数未达到所述服务器的采样条件时,将所述本轮聚合参数作为下轮初始模型参数,并在所述客户端集合中重新采样多个下轮客户端,向所述下轮客户端发送所述下轮初始模型参数,以通过所述下轮客户端进行下个轮次的模型训练;
当所述本轮聚合参数达到所述服务器的采样条件时,将所述本轮聚合参数发送至所述客户端集合内的每个客户端中,每个所述客户端按所述本轮聚合参数配置本地模型后进行最后一轮模型训练;
其中,所述本轮客户端对本地模型的训练过程包括:
若所述本轮客户端为非首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
以所述本轮客户端中当前的本地模型作为参照模型,基于所述参照模型对训练样本的中间层输出数据以及本轮训练中的本地模型对训练样本的中间层输出数据之间的KL散度,确定第三损失函数;基于所述参照模型对训练样本属于不同类别的预测概率以及本轮训练中的本地模型对训练样本属于不同类别的预测概率之间的KL散度,确定第四损失函数;将所述第三损失函数和所述第四损失函数的加权求和结果作为第二损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数与所述第二损失函数之和作为所述本地模型的目标损失函数,对所述本地模型进行训练。
2.根据权利要求1所述的方法,其特征在于,所述服务器在客户端集合中采样多个本轮客户端之前,所述方法还包括:
初始化元模型参数,将所述元模型的初始化参数作为第一轮初始模型参数,其中,所述客户端集合中每个所述客户端对应的本地模型的模型结构均与所述元模型的模型结构相同。
3.根据权利要求2所述的方法,其特征在于,所述将所述本轮目标模型参数返回至所述服务器中,包括:
所述本轮客户端将所述本轮目标模型参数以及本轮训练样本量返回至所述服务器中;
所述服务器对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数,包括:
所述服务器根据每个所述本轮客户端对应的所述本轮训练样本量占本轮训练样本总数的比例,确定每个所述本轮客户端对应的所述本轮目标模型参数的参数权重,并按所述参数权重对所述本轮目标模型参数进行加权求和,得到所述本轮聚合参数。
4.根据权利要求2或3所述的方法,其特征在于,所述元模型为分类模型,用于预测目标输入量属于不同类别的概率;所述本轮客户端按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,包括:
若所述本轮客户端为首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;
所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数作为所述本地模型的目标损失函数,对所述本地模型进行训练。
5.根据权利要求2所述的方法,其特征在于,所述采样条件为采样轮次阈值;所述元模型为图像分类模型。
6.一种基于联邦学习的模型训练系统,其特征在于,所述系统包括:
服务器和多个客户端,其中,多个客户端构成客户端集合;
所述服务器,用于在客户端集合中采样多个本轮客户端,将本轮初始模型参数发送至所述本轮客户端中;
所述本轮客户端,用于按所述本轮初始模型参数配置本地模型后进行模型训练,得到训练后的本轮目标模型参数,并将所述本轮目标模型参数返回至所述服务器中;
所述服务器,还用于对多个所述本轮客户端各自返回的所述本轮目标模型参数进行参数聚合,得到本轮聚合参数;
所述服务器,还用于当所述本轮聚合参数未达到所述服务器的采样条件时,将所述本轮聚合参数作为下轮初始模型参数,并在所述客户端集合中重新采样多个下轮客户端,向所述下轮客户端发送所述下轮初始模型参数,以通过所述下轮客户端进行下个轮次的模型训练;
所述服务器,还用于当所述本轮聚合参数达到所述服务器的采样条件时,将所述本轮聚合参数发送至所述客户端集合内的每个客户端中,每个所述客户端按所述本轮聚合参数配置本地模型后进行最后一轮模型训练;
所述本轮客户端,还用于:若所述本轮客户端为非首次被采样,则依据训练样本所属的标签类别与所述本地模型预测出的训练样本属于所述标签类别的概率之间的交叉熵,确定所述本地模型的第一损失函数;以所述本轮客户端中当前的本地模型作为参照模型,基于所述参照模型对训练样本的中间层输出数据以及本轮训练中的本地模型对训练样本的中间层输出数据之间的KL散度,确定第三损失函数;基于所述参照模型对训练样本属于不同类别的预测概率以及本轮训练中的本地模型对训练样本属于不同类别的预测概率之间的KL散度,确定第四损失函数;将所述第三损失函数和所述第四损失函数的加权求和结果作为第二损失函数;所述本轮客户端按所述本轮初始模型参数配置所述本地模型后,将所述第一损失函数与所述第二损失函数之和作为所述本地模型的目标损失函数,对所述本地模型进行训练。
7.一种存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至5中任一项所述基于联邦学习的模型训练方法。
8.一种计算机设备,包括存储介质、处理器及存储在存储介质上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至5中任一项所述基于联邦学习的模型训练方法。
CN202210939615.5A 2022-08-05 2022-08-05 基于联邦学习的模型训练方法及系统、存储介质 Active CN115018019B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210939615.5A CN115018019B (zh) 2022-08-05 2022-08-05 基于联邦学习的模型训练方法及系统、存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210939615.5A CN115018019B (zh) 2022-08-05 2022-08-05 基于联邦学习的模型训练方法及系统、存储介质

Publications (2)

Publication Number Publication Date
CN115018019A CN115018019A (zh) 2022-09-06
CN115018019B true CN115018019B (zh) 2022-11-01

Family

ID=83065901

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210939615.5A Active CN115018019B (zh) 2022-08-05 2022-08-05 基于联邦学习的模型训练方法及系统、存储介质

Country Status (1)

Country Link
CN (1) CN115018019B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115086399B (zh) * 2022-07-28 2022-12-06 深圳前海环融联易信息科技服务有限公司 基于超网络的联邦学习方法、装置及计算机设备
CN116050548B (zh) * 2023-03-27 2023-07-04 深圳前海环融联易信息科技服务有限公司 一种联邦学习方法、装置及电子设备

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021190638A1 (zh) * 2020-11-24 2021-09-30 平安科技(深圳)有限公司 基于非均匀分布数据的联邦建模方法及相关设备

Family Cites Families (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112365007A (zh) * 2020-11-11 2021-02-12 深圳前海微众银行股份有限公司 模型参数确定方法、装置、设备及存储介质
CN112446040A (zh) * 2020-11-24 2021-03-05 平安科技(深圳)有限公司 基于选择性梯度更新的联邦建模方法及相关设备
CN113191503B (zh) * 2021-05-20 2023-06-09 清华大学深圳国际研究生院 一种非共享数据的去中心化的分布式学习方法及系统
CN113435125A (zh) * 2021-07-06 2021-09-24 山东大学 一种面向联邦物联网系统的模型训练加速方法与系统
CN114357067A (zh) * 2021-12-15 2022-04-15 华南理工大学 一种针对数据异构性的个性化联邦元学习方法
CN114387580A (zh) * 2022-01-06 2022-04-22 厦门大学 基于联邦学习的模型训练方法及装置

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021190638A1 (zh) * 2020-11-24 2021-09-30 平安科技(深圳)有限公司 基于非均匀分布数据的联邦建模方法及相关设备

Also Published As

Publication number Publication date
CN115018019A (zh) 2022-09-06

Similar Documents

Publication Publication Date Title
CN115018019B (zh) 基于联邦学习的模型训练方法及系统、存储介质
CN112085172B (zh) 图神经网络的训练方法及装置
CN113516255A (zh) 联邦学习建模优化方法、设备、可读存储介质及程序产品
CN109597965B (zh) 基于深度神经网络的数据处理方法、系统、终端及介质
CN104424507B (zh) 一种回声状态网络的预测方法和预测装置
CN113902473A (zh) 业务预测系统的训练方法及装置
CN112085615A (zh) 图神经网络的训练方法及装置
JP2023525727A (ja) 領域区分方法、装置、電子機器及びコンピュータプログラム
CN111626767B (zh) 资源数据的发放方法、装置及设备
CN110599312A (zh) 基于信用的交互信用评估方法以及装置
US20220351039A1 (en) Federated learning using heterogeneous model types and architectures
CN111815169A (zh) 业务审批参数配置方法及装置
CN112948885A (zh) 实现隐私保护的多方协同更新模型的方法、装置及系统
US20190362197A1 (en) Efficient incident management in large scale computer systems
CN115439192A (zh) 医疗商品信息的推送方法及装置、存储介质、计算机设备
CN114140033A (zh) 一种服务人员的分配方法、装置、电子设备及存储介质
CN114154392A (zh) 基于区块链和联邦学习的模型共建方法、装置及设备
CN112887371B (zh) 边缘计算方法、装置、计算机设备及存储介质
CN111210279B (zh) 一种目标用户预测方法、装置和电子设备
CN111163237B (zh) 呼叫业务流程控制方法和相关装置
CN110942345B (zh) 种子用户的选取方法、装置、设备及存储介质
KR102430775B1 (ko) 비즈니스 모델에 대한 시장성, 사용성 및 기술성 테스트를 제공하기 위한 통합 테스트 플랫폼 시스템 및 그 동작 방법
CN114996434B (zh) 一种信息抽取方法及装置、存储介质、计算机设备
EP4036811A1 (en) Combining compression, partitioning and quantization of dl models for fitment in hardware processors
CN112365189A (zh) 案件分配方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant