CN114580663A - 面向数据非独立同分布场景的联邦学习方法及系统 - Google Patents
面向数据非独立同分布场景的联邦学习方法及系统 Download PDFInfo
- Publication number
- CN114580663A CN114580663A CN202210192242.XA CN202210192242A CN114580663A CN 114580663 A CN114580663 A CN 114580663A CN 202210192242 A CN202210192242 A CN 202210192242A CN 114580663 A CN114580663 A CN 114580663A
- Authority
- CN
- China
- Prior art keywords
- local
- data
- model
- global
- client
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F9/00—Arrangements for program control, e.g. control units
- G06F9/06—Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
- G06F9/46—Multiprogramming arrangements
- G06F9/50—Allocation of resources, e.g. of the central processing unit [CPU]
- G06F9/5061—Partitioning or combining of resources
- G06F9/5066—Algorithms for mapping a plurality of inter-dependent sub-tasks onto a plurality of physical CPUs
Landscapes
- Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种面向数据非独立同分布场景的联邦学习方法和系统,包括多个客户端和中心服务器;中心服务器用于将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,并分配子数据集到客户端;客户端用于基于接收的子数据集,依据当前本地锚点指导子数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器;中心服务器还用于根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础,该方法在保障用户数据安全的基础上,提升特定场景下联邦学习系统的实用性,同时解决联邦学习系统的通信效率、统计异构问题。
Description
技术领域
本发明属于人工智能和信息安全领域,尤其涉及一种面向数据非独立同分布场景的联邦学习方法及系统。
背景技术
随着大数据、人工智能、云计算等新技术在各行业不断深入应用,全球数据呈现爆发增长、海量集聚的特点,数据的价值愈发凸显。数据作为生产要素的流通交易,面临确权和隐私保护两大关键难题。数据本质上是信息,不具备独享性或专享性,多数人可同时占有。数字经济时代,与个人有关的信息传播边际成本几乎为零,能够迅速传遍整个世界,这种低成本使得数据保护面临特殊困难。目前,公司和组织等越来越多地收集用户的详细信息,一方面,能通过这些属于不同组织的原始数据抽取出有价值的信息,这些信息能通过机器学习技术来提升产品、服务和福利的质量;另一方面,在分布式场景下会存在潜在的滥用和攻击行为,这对数据隐私和安全提出了极大地挑战。传统从用户端收集数据整合后训练机器学习模型的方式被担忧可能会侵犯隐私。
联邦学习(Federated Learning,FL)提供了一种灵活的解决方式,允许机器学习应用以一种保留隐私数据在本地的方式进行分布式学习,在不违反法律和道德要求的情况下,使多个隐私数据拥有方参与合作训练共同的模型。联邦学习允许不同组织和设备共同参与训练构建机器学习模型,这些组织和设备使用的数据存储在本地。联邦学习算法在保证模型训练过程的同时不侵犯用户隐私数据,用户在本地挖掘数据信息的价值,最终构建和使用机器学习模型。
McMahan等人提出的FedAvg算法是目前最广泛采用的联邦学习的基础,它通过客户端在本地训练多次迭代来降低通信成本。联邦学习涉及的主要挑战包括通信效率、系统异质性、统计异质性和隐私。为了降低联邦学习中的通信成本,一些研究建议使用quantization和sketching等数据压缩技术,还有一些人建议采用split learning的方式。为了解决系统异构问题,相继提出了异步通信和客户端主动采样技术。统计异质性是当前联邦学习研究的热点。一个研究趋势是调整全局模型以适应数据非独立同分布的个性化局部模型,例如,通过将联邦学习与assisted learning、meta-learning、multi-tasklearning、transfer learning、knowledge distillation一些学习方法结合起来。然而,这些个性化方法通常会引入可能不必要的额外计算和通信开销,最终的效果也不尽如人意。
发明内容
鉴于上述,本发明的目的在于提供一种面向数据非独立同分布场景的联邦学习方法及系统,在保障用户数据安全的基础上,进一步提升特定场景下联邦学习系统的实用性,同时解决联邦学习系统的通信效率、统计异构问题。
为实现上述发明目的,实施例提供了一种面向数据非独立同分布场景的联邦学习系统,包括多个客户端和中心服务器,每个客户端与中心服务器建有通信通道;
中心服务器用于将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,还用于分配子数据集到客户端使得每个客户端均拥有1个子数据集;
客户端用于基于接收的子数据集,依据当前全局锚点指导子数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器,其中,本地锚点为根据本地数据对应的特征向量确定的能够代表分类类别的特征点,模型数据包括新本地锚点、本地模型参数;
中心服务器还用于根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础,其中,聚合数据包括对所有新本地锚点聚合得到的全局锚点、对所有本地模型参数聚合得到的全局模型参数。
在一个实施例中,中心服务器将目标数据集以非独立分布方式划分成多个子数据集时,采用基于标签的非独立分布方式,包括:
首先,根据联邦学习系统中客户端数量与数据标签种类设置双随机矩阵的行参数和列参数;
然后,根据设置的行参数和列参数生成双随机矩阵;
最后,将目标数据集按照生成的双随机矩阵进行数据划分,以使每个子数据集包含所有种类数据。
在一个实施例中,客户端依据当前全局锚点指导子数据集训练当前本地模型时,采用的损失函数total loss为:
total loss=base loss+α*anchor loss+β*triplet loss
其中,α和β为超参数,base loss为基于本地模型输出的分类预测结果构建关于分类预测任务的任务损失;
anchor loss为基于当前本地锚点与每类数据对应的特征向量构建指导本地模型训练的指导损失;
triplet loss为不同类别的特征向量之间的对比损失。
在一个实施例中,指导损失anchor loss表示为:
di=distance(embeddingi,local anchori)
其中,i为类别索引,k为总类别数,yi表示第i类别的标签,di表示第i类数据所生成的特征向量embeddingi与当前本地锚点local anchori之间的距离,distance()表示距离计算操作,包括欧式距离计算操作或余弦相似度操作,CrossEntropyLoss()表示交叉熵损失函数;
对比损失triplet loss表示为:
dij=distance(embeddingi,embeddingj)
其中,j为类别索引,yj表示第j类别的标签,margin表示预设阈值,取值为0-1,max{}表示求最大值函数,dij表示第i类数据所生成的特征向量embeddingi与第j类数据所生成的特征向量embeddingj之间的距离。
在一个实施例中,客户端在训练本地模型时,采用动量优化算法更新本地锚点,具体公式为:
local anchori’=γ*embeddingi+(1-γ)*local anchori,i=1,...,k
其中,i为类别索引,k为总类别数,local anchori’表示第i类数据对应的新本地锚点,local anchori表示第i类数据对应的当前本地锚点,embeddingi表示第i类数据所生成的特征向量embeddingi,γ表示锚点更新动量,为本地训练过程中的每批数据数量batchsize与本地训练过程中数据总量data size的比值,即
在一个实施例中,中心服务器对所有本地模型参数聚合采用的聚合方式为:
其中,n表示本地模型索引,N为本地模型总个数,global model表示全局模型参数,local modeln表示第n个本地模型参数,local weightn表示第n个本地模型参数的聚合权重,local datan表示训练第n个本地模型的子数据集的数据量,total data表示目标数据集的数据量。
在一个实施例中,中心服务器对所有新本地锚点聚合采用的聚合方式为:
其中,global anchor表示全局锚点,local anchornj表示第n个本地模型所对应的第j类标签的新本地锚点,anchor weightnj表示第n个本地模型所对应的第j类标签的锚点权重,计算方式如下:
在一个实施例中,在联邦学习之前,中心服务器还用于构建联邦学习任务并下发至各客户端,其中,联邦学习任务包括初始本地模型、初始本地锚点、联邦学习超参数,其中,联邦学习超参数包括梯度下降算法优化器、本地训练学习率、参与联邦学习的客户端总量、客户端每轮本地训练的轮次、总通信轮次、本地锚点的空间维度、权值衰减率、约定的通信方式;客户端与中心服务器依据联邦学习任务进行联邦学习;
其中,初始本地锚点个数与目标数据集的标签种类相等;
本地模型包括特征提取网络、投影网络以及输出层,其中,特征提取网络用于提取输入样本数据的特征图,投影网络用于将特征图投影到特征投影空间以得到特征向量,输出层用于根据特征向量进行预测计算并数据分类预测结果。
在一个实施例中,所述通信方式包括4种:
第一种,上传的模型数据包括本地模型参数和新本地锚点,下传的聚合数据包括全局模型参数和全局锚点;
第二种,上传的模型数据包括新本地锚点,下传的聚合数据包括全局模型参数和全局锚点;
第三种,上传的模型数据包括本地模型参数和新本地锚点,下传的聚合数据包括全局锚点;
第四种,上传的模型数据包括新本地锚点,下传的聚合数据包括全局锚点;
中心服务器与客户端选择第一种通信方式通信时,客户端在本地训练模型后,同时上传本地模型参数和新本地锚点至中心服务器;中心服务器对所有本地模型参数和新本地锚点分别进行聚合以得到全局模型参数和全局锚点,并下传全局模型参数和全局锚点至客户端作为当前本地模型和当前全局锚点以进行下一轮本地训练;
中心服务器与客户端选择第二种通信方式通信时,客户端在本地训练模型后,上传新本地锚点至中心服务器;中心服务器对所有新本地锚点进行聚合以得到全局锚点,并下传历史存储的全局模型参数和全局锚点至客户端作为当前本地模型和当前全局锚点以进行下一轮本地训练;
中心服务器与客户端选择第三种通信方式通信时,客户端在本地训练模型后,同时上传本地模型参数和新本地锚点至中心服务器;中心服务器对所有本地模型参数和新本地锚点分别进行聚合以得到全局模型参数和全局锚点,并下传全局锚点至客户端作为当前全局锚点以进行下一轮本地训练;
中心服务器与客户端选择第四种通信方式通信时,客户端在本地训练模型后,上传新本地锚点至中心服务器;中心服务器对所有本地锚点进行聚合以得到全局锚点,并下传全局锚点至客户端作为当前全局锚点以进行下一轮本地训练。
为实现上述发明目的,实施例还提供了一种面向数据非独立同分布场景的联邦学习方法,所述联邦学习方法采用上述联邦学习系统,所述联邦学习方法包括:
中心服务器与客户端建立通信通道,中心服务器创建联邦学习任务,并下发至各客户端;
中心服务器将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,并分配子数据集到客户端使得每个客户端均拥有1个子数据集;
客户端基于接收的子数据集,依据作为当前本地锚点的全局锚点指导子数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器,其中,本地锚点为每类数据对应的特征向量中能够代表分类类别的特征点,模型数据包括新本地锚点、本地模型参数;
中心服务器根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础,其中,聚合数据包括对所有新本地锚点聚合得到的全局锚点、对所有本地模型参数聚合得到的全局模型参数。
与现有技术相比,本发明具有的有益效果至少包括:
中心服务器采用非独立分布方式将目标数据集划分成多个子数据集使得每个子数据集包含所有种类数据,并下发每个子数据集到各客户端,使得各客户端拥有的子数据集既包含了所有种类数据,又不与其他客户端进行数据通信,形成了适用于联邦学习的场景,保证联邦学习场景下对目标数据集的充分利用,提成目标数据集的利用率和联邦学习的准确性;
客户端基于本地锚点指导子数据集训练当前本地模型,由于本地锚点是从本地数据对应的特征映射中提取的能够代表分类类别的特征点,因此,利用本地锚点指导和约束训练,能够提高联邦学习效率,同时中心服务器基于客户端上传的本地锚点进行聚合,然后再下发全局锚点至各客户端以进行下一轮联邦学习,这样全局锚点实现了对全局数据分布的指导和约束。
在联邦学习过程中,通过凝练本地数据知识的本地锚点代替本地模型上传至中心服务器进行全局信息交流,由于本地锚点相对于本地模型参数降低了至少1000个数量级,这样在保障模型训练精度的基础上大幅度地降低了通信成本,且根据通信方式能够使得中心参数服务器和本地客户端可以根据诉求进行自适应的调节具体的通信方式。
由于本地锚点是与数据的特征向量密切相关的能够代表每类分类标签的特征点,因此基于本地锚点指导的训练方式可以兼容到现有的各种联邦聚合算法,并不构成冲突,具有较好的兼容性。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图做简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动前提下,还可以根据这些附图获得其他附图。
图1是实施例提供的面向数据非独立同分布场景的联邦学习系统的流程图;
图2是实施例提供的联邦学习原理图,其中,①中心服务器向客户端传输全局模型和全局锚点。②客户端利用子数据集进行本地模型训练。③子数据集经过本地模型之后输出预测结果和特征向量embedding。④利用特征向量embedding更新本地锚点。⑤客户端向中心服务器传输本地模型参数和新本地锚点。⑥中心服务器用模型权重和锚点权重聚合本地模型参数和新本地锚点,获得新全局模型和全局锚点;
图3是实施例提供的面向数据非独立同分布场景的联邦学习方法的流程图。
具体实施方式
为使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例对本发明进行进一步的详细说明。应当理解,此处所描述的具体实施方式仅仅用以解释本发明,并不限定本发明的保护范围。
图1是实施例提供的面向数据非独立同分布场景的联邦学习系统的流程图。图2是实施例提供的联邦学习原理图。如图1和图2所示,实施例提供的面向数据非独立同分布场景的联邦学习系统,包括1个中心服务器和多个客户端,其中,客户端可以是任意具有算力的电子设备,包括但不限于智能手表、智能手机、各类电脑等。中心服务器可以设于终端的终端服务器,也可以是设于云端的云端服务器。中心服务器与各参与联邦学习的客户端建有通信通道,以实现数据和消息的安全通信。客户端与中心服务器的配合实现了数据非独立同分布场景下的联邦学习任务。
实施例中,中心服务器作为面向数据非独立同分布场景的联邦学习任务的指导者,用于将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,还用于分配子数据集到客户端使得每个客户端均拥有1个子数据集。
现有方式中,采用基于数据的非独立分布方式,即将目标数据集的所有数据按照狄利克雷分布(Dirichlet distribution)随机采样,并将采样结果分配至各客户端,这样分配到各客户端的子数据集包含的数据种类不均匀,有些客户端被分配的子数据集包含全部类标的数据,有些客户端被分配的子数据集则不包含全部类标的数据,这会导致基于数据不均匀分布的子数据集联邦学习效果不好,得到的模型鲁棒性差。
为此,实施例中,中心服务器将目标数据集以非独立分布方式划分成多个子数据集时,采用基于标签的非独立分布方式,包括:首先,根据联邦学习系统中客户端数量与数据标签种类设置双随机矩阵的行参数和列参数;然后,根据设置的行参数和列参数生成双随机矩阵;最后,将目标数据集按照生成的双随机矩阵进行数据划分,以使每个子数据集包含所有种类数据。
实施例中,依据客户端数量与数据标签种类,确定双随机矩阵的维度(n,k),在概率和组合数学中,双随机矩阵是一个非负实数的方阵,其每一行和每一列的总和为1,即:
实施例中,采用BIRKHOFF-VON NEUMANN算法生成双随机矩阵。该算法表明,在双随机矩阵X中,存在θ1,...,θk≥0,和置换矩阵P1,...,Pk,使得X=θ1P1+…+θkPk。xij表示双随机矩阵中第i行第j列元素值,代表分给第i个客户端且属于第k类标签的数据概率,n和k分别表示客户端数量和数据标签种类。具体地,在生成双随机矩阵X时,基于Dirichlet distribution的随机采样生成θ1,...,θk≥0,基于PCG-64伪随机数生成器生成置换矩阵P1,...,Pk,PCG-64是一个O’Neill排列同余生成器的128位实现。这样生成的双随机矩阵X满足可用于目标数据集的划分。
实施例中,利用生成的双随机矩阵X对目标数据集进行数据划分,以使每个非独立同分布的子数据集包含所有标签的数据,利于非独立同分布场景的联邦学习。
实施例中,中心服务器作为联邦学习系统的指导者,还用于构建联邦学习任务,并下发至各客户端以组织进行联邦学习。其中,联邦学习任务包括初始本地模型、初始全局锚点、联邦学习超参数,其中,联邦学习超参数包括梯度下降算法优化器、本地训练学习率、参与联邦学习的客户端总量、客户端每轮本地训练的轮次、总通信轮次、全局锚点的空间维度、权值衰减率、约定的通信方式、训练终止条件等;客户端与中心服务器依据联邦学习任务进行联邦学习;训练终止条件包括达到总通讯轮次或全局模型收敛。
实施例中,采用上述PCG-64伪随机数生成器并根据数据标签种类随机初始化全局锚点,以生成(n,k,a)维矩阵。其中,n表示客户端数量,k表示目标数据集的标签种类,a表示全局锚点维度,且初始全局锚点个数与目标数据集的标签种类相等。可以设置全局锚点维度为64、128、256。
本地模型包括特征提取网络、投影网络以及输出层,其中,特征提取网络用于提取输入样本数据的特征图,投影网络用于将特征图投影到限定维度的特征投影空间,得到对应的特征向量。输出层用于根据特征向量进行预测计算并数据分类预测结果。优选地,特征提取网络可以是Simple-CNN网络的特征提取网络、Vgg-11的特征提取网络、ResNet-18的特征提取网络、基于CNN的特征提取网络以及ResNet-54的特征提取网络等。投影网络可以采用MLP网络。输出层可以采用全连接神经网络。
客户端用于基于接收的子数据集,依据当前全局锚点指导值数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器。
实施例中,客户端在进行初始轮次的联邦学习时,当前全局锚点为中心服务器下发的初始全局锚点;在进行非初始轮次的其他轮次联邦学习时,当前轮次的全局锚点为中心服务器下发的全局锚点。本地锚点为根据本地数据对应的特征向量确定的能够代表分类类别的特征点,模型数据包括新本地锚点、本地模型参数。
实施例中,客户端依据当前全局锚点指导子数据集训练当前本地模型时,采用的损失函数total loss为:
total loss=base loss+α*anchor loss+β*triplet loss
anchor loss为基于当前全局锚点与每类数据对应的特征向量构建指导本地模型训练的指导损失,优选地,指导损失anchor loss表示为:
di=distance(embeddingi,global anchori)
其中,i为类别索引,k为总类别数,yi表示第i类别的标签,di表示第i类数据所生成的特征向量embeddingi与当前全局锚点global anchori之间的距离,可以采用1范数计算,distance()表示距离计算操作,包括欧式距离计算操作或余弦相似度操作,CrossEntropyLoss()表示交叉熵损失函数。
triplet loss为不同类别的特征向量之间的对比损失。优选地,对比损失tripletloss表示为:
dij=distance(embeddingi,embeddingj)
其中,j为类别索引,yj表示第j类别的标签,margin表示预设阈值,取值为0-1,max{}表示求最大值函数,dij表示第i类数据所生成的特征向量e ingi与第j类数据所生成的特征向量embeddingj之间的距离。
实施例中,客户端在训练本地模型时,采用损失函数total loss并进行反向传播来更新本地模型参数,同时采用动量优化算法更新本地锚点,具体公式为:
local anchori’=γ*embeddingi+(1-γ)*local anchori,i=1,...,k
其中,i为类别索引,k为总类别数,local anchori’表示第i类数据对应的新本地锚点,γ表示锚点更新动量,为本地训练过程中的每批数据数量batch size与本地训练过程中数据总量data size的比值,即embeddingi表示第i类数据所生成的特征向量embeddingi,local anchori表示第i类数据对应的当前本地锚点,需要说明的是,当前本地锚点为上一轮中心服务器下发的全局锚点,即客户端的local anchor在每一个通讯轮次,初始化为global anchor,表示为local anchori=global anchor。
在联邦学习任务中,中心服务器还用于根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础。聚合数据包括对所有新本地锚点聚合得到的全局锚点、对所有本地模型参数聚合得到的全局模型参数。
实施例中,中心服务器对所有本地模型参数聚合采用的聚合方式为:
其中,n表示本地模型索引,N为本地模型总个数,global model表示全局模型参数,local modeln表示第n个本地模型参数,local weightn表示第n个本地模型参数的聚合权重,local datan表示训练第n个本地模型的子数据集的数据量,total data表示目标数据集的数据量。
根据子数据集和目标数据集的数据量比值作为模型聚合权重,该聚合权重更具有加权价值,以得到更加准确的全局模型。
实施例中,中心服务器对所有新本地锚点聚合采用的聚合方式为:
其中,global anchor表示全局锚点,local anchornj表示第n个本地模型所对应的第j类标签的新本地锚点,anchor weightnj表示第n个本地模型所对应的第j类标签的锚点权重,计算方式如下:
根据每个客户端的每类标签的数据量与该类标签在目标数据集中的总数据量确定锚点权重,该锚点权重更具有加权价值,以得到更加准确的全局锚点。
在联邦学习系统中,提供4种通信方式,在联邦学习时,客户端与中心服务器协商通信方式,后续联邦学习及按照协商约定的通信方式进行模型数据的上传和聚合数据的下发。
实施例中,提供以下4种通信方式:
第一种,上传的模型数据包括本地模型参数和新本地锚点,下传的聚合数据包括全局模型参数和全局锚点。中心服务器与客户端选择第一种通信方式通信时,客户端在本地训练模型后,同时上传本地模型参数和新本地锚点至中心服务器;中心服务器对所有本地模型参数和新本地锚点分别进行聚合以得到全局模型参数和全局锚点,并下传全局模型参数和全局锚点至客户端作为当前本地模型和当前全局锚点以进行下一轮本地训练。
第二种,上传的模型数据包括新本地锚点,下传的聚合数据包括全局模型参数和全局锚点。中心服务器与客户端选择第二种通信方式通信时,客户端在本地训练模型后,上传新本地锚点至中心服务器;中心服务器对所有新本地锚点进行聚合以得到全局锚点,并下传历史存储的全局模型参数和全局锚点至客户端作为当前本地模型和当前全局锚点以进行下一轮本地训练。
第三种,上传的模型数据包括本地模型参数和新本地锚点,下传的聚合数据包括全局锚点。中心服务器与客户端选择第三种通信方式通信时,客户端在本地训练模型后,同时上传本地模型参数和新本地锚点至中心服务器;中心服务器对所有本地模型参数和新本地锚点分别进行聚合以得到全局模型参数和全局锚点,并下传全局锚点至客户端作为当前全局锚点以进行下一轮本地训练。
第四种,上传的模型数据包括新本地锚点,下传的聚合数据包括全局锚点。中心服务器与客户端选择第四种通信方式通信时,客户端在本地训练模型后,上传新本地锚点至中心服务器;中心服务器对所有本地锚点进行聚合以得到全局锚点,并下传全局锚点至客户端作为当前全局锚点以进行下一轮本地训练。
需要说明的是,在整个联邦学习过程中,可以采用不同的通信方式,例如,通讯轮次为100轮,其中第1~30轮采用第一种通信方式,即下传/上传模型参数和锚点的通讯方式,第31~60轮采用第二种通信方式,即下传模型参数和锚点、上传锚点的通讯方式,第61~100轮采用第四种通信方式,即下传/上传锚点的通讯方式。联邦学习系统中的各个客户端与中心服务器按照约定的通讯方式进行多轮通讯,经过高效的联邦聚合过程,实现高精度的全局模型。
图3是实施例提供的面向数据非独立同分布场景的联邦学习方法的流程图。如图3所示,实施例还提供了一种面向数据非独立同分布场景的联邦学习方法,该联邦学习方法采用上述联邦学习系统,所述联邦学习方法包括:
步骤1,中心服务器与客户端建立通信通道,中心服务器创建联邦学习任务,并下发至各客户端;
步骤2,中心服务器将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,并分配子数据集到客户端使得每个客户端均拥有1个子数据集;
步骤3,客户端基于接收的子数据集,依据作为当前本地锚点的全局锚点指导子数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器;
其中,本地锚点为每类数据对应的特征向量中能够代表分类类别的特征点,模型数据包括新本地锚点、本地模型参数;
步骤4,中心服务器根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础;
其中,聚合数据包括对所有新本地锚点聚合得到的全局锚点、对所有本地模型参数聚合得到的全局模型参数。
以上所述的具体实施方式对本发明的技术方案和有益效果进行了详细说明,应理解的是以上所述仅为本发明的最优选实施例,并不用于限制本发明,凡在本发明的原则范围内所做的任何修改、补充和等同替换等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种面向数据非独立同分布场景的联邦学习系统,其特征在于,包括多个客户端和中心服务器,每个客户端与中心服务器建有通信通道;
中心服务器用于将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,还用于分配子数据集到客户端使得每个客户端均拥有1个子数据集;
客户端用于基于接收的子数据集,依据当前全局锚点指导子数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器,其中,本地锚点为根据本地数据对应的特征向量确定的能够代表分类类别的特征点,模型数据包括新本地锚点、本地模型参数;
中心服务器还用于根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础,其中,聚合数据包括对所有新本地锚点聚合得到的全局锚点、对所有本地模型参数聚合得到的全局模型参数。
2.根据权利要求1所述的面向数据非独立同分布场景的联邦学习系统,其特征在于,中心服务器将目标数据集以非独立分布方式划分成多个子数据集时,采用基于标签的非独立分布方式,包括:
首先,根据联邦学习系统中客户端数量与数据标签种类设置双随机矩阵的行参数和列参数;
然后,根据设置的行参数和列参数生成双随机矩阵;
最后,将目标数据集按照生成的双随机矩阵进行数据划分,以使每个子数据集包含所有种类数据。
3.根据权利要求1所述的面向数据非独立同分布场景的联邦学习系统,客户端依据当前全局锚点指导子数据集训练当前本地模型时,采用的损失函数total loss为:
total loss=base loss+α*anchor loss+β*triplet loss
其中,α和β为超参数,base loss为基于本地模型输出的分类预测结果构建关于分类预测任务的任务损失;
anchor loss为基于当前本地锚点与每类数据对应的特征向量构建指导本地模型训练的指导损失;
triplet loss为不同类别的特征向量之间的对比损失。
4.根据权利要求3所述的面向数据非独立同分布场景的联邦学习系统,指导损失anchor loss表示为:
di=distance(embeddingi,local anchori)
其中,i为类别索引,k为总类别数,yi表示第i类别的标签,di表示第i类数据所生成的特征向量embeddingi与当前本地锚点local anchori之间的距离,distance()表示距离计算操作,包括欧式距离计算操作或余弦相似度操作,CrossEntropyLoss()表示交叉熵损失函数;
对比损失triplet loss表示为:
dij=distance(embeddingi,embeddingj)
其中,j为类别索引,yj表示第j类别的标签,margin表示预设阈值,取值为0-1,max{}表示求最大值函数,dij表示第i类数据所生成的特征向量embeddingi与第j类数据所生成的特征向量embeddingj之间的距离。
8.根据权利要求1所述的面向数据非独立同分布场景的联邦学习系统,在联邦学习之前,中心服务器还用于构建联邦学习任务并下发至各客户端,其中,联邦学习任务包括初始本地模型、初始本地锚点、联邦学习超参数,其中,联邦学习超参数包括梯度下降算法优化器、本地训练学习率、参与联邦学习的客户端总量、客户端每轮本地训练的轮次、总通信轮次、本地锚点的空间维度、权值衰减率、约定的通信方式;客户端与中心服务器依据联邦学习任务进行联邦学习;
其中,初始本地锚点个数与目标数据集的标签种类相等;
本地模型包括特征提取网络、投影网络以及输出层,其中,特征提取网络用于提取输入样本数据的特征图,投影网络用于将特征图投影到特征投影空间以得到特征向量,输出层用于根据特征向量进行预测计算并数据分类预测结果。
9.根据权利要求1所述的面向数据非独立同分布场景的联邦学习系统,所述通信方式包括4种:
第一种,上传的模型数据包括本地模型参数和新本地锚点,下传的聚合数据包括全局模型参数和全局锚点;
第二种,上传的模型数据包括新本地锚点,下传的聚合数据包括全局模型参数和全局锚点;
第三种,上传的模型数据包括本地模型参数和新本地锚点,下传的聚合数据包括全局锚点;
第四种,上传的模型数据包括新本地锚点,下传的聚合数据包括全局锚点;
中心服务器与客户端选择第一种通信方式通信时,客户端在本地训练模型后,同时上传本地模型参数和新本地锚点至中心服务器;中心服务器对所有本地模型参数和新本地锚点分别进行聚合以得到全局模型参数和全局锚点,并下传全局模型参数和全局锚点至客户端作为当前本地模型和当前全局锚点以进行下一轮本地训练;
中心服务器与客户端选择第二种通信方式通信时,客户端在本地训练模型后,上传新本地锚点至中心服务器;中心服务器对所有新本地锚点进行聚合以得到全局锚点,并下传历史存储的全局模型参数和全局锚点至客户端作为当前本地模型和当前全局锚点以进行下一轮本地训练;
中心服务器与客户端选择第三种通信方式通信时,客户端在本地训练模型后,同时上传本地模型参数和新本地锚点至中心服务器;中心服务器对所有本地模型参数和新本地锚点分别进行聚合以得到全局模型参数和全局锚点,并下传全局锚点至客户端作为当前全局锚点以进行下一轮本地训练;
中心服务器与客户端选择第四种通信方式通信时,客户端在本地训练模型后,上传新本地锚点至中心服务器;中心服务器对所有本地锚点进行聚合以得到全局锚点,并下传全局锚点至客户端作为当前全局锚点以进行下一轮本地训练。
10.一种面向数据非独立同分布场景的联邦学习方法,其特征在于,所述联邦学习方法采用权利要求1-9任一项所述的联邦学习系统,所述联邦学习方法包括:
中心服务器与客户端建立通信通道,中心服务器创建联邦学习任务,并下发至各客户端;
中心服务器将目标数据集以非独立分布方式划分成多个子数据集使得每个子数据集包含所有种类数据,并分配子数据集到客户端使得每个客户端均拥有1个子数据集;
客户端基于接收的子数据集,依据作为当前本地锚点的全局锚点指导子数据集训练当前本地模型,并更新本地锚点和本地模型参数,依据约定的通信方式上传模型数据至中心服务器,其中,本地锚点为每类数据对应的特征向量中能够代表分类类别的特征点,模型数据包括新本地锚点、本地模型参数;
中心服务器根据接收的模型数据进行聚合得到聚合数据,依据约定的通信方式下传聚合数据至客户端以作为下一轮联邦学习的基础,其中,聚合数据包括对所有新本地锚点聚合得到的全局锚点、对所有本地模型参数聚合得到的全局模型参数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210192242.XA CN114580663A (zh) | 2022-03-01 | 2022-03-01 | 面向数据非独立同分布场景的联邦学习方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210192242.XA CN114580663A (zh) | 2022-03-01 | 2022-03-01 | 面向数据非独立同分布场景的联邦学习方法及系统 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114580663A true CN114580663A (zh) | 2022-06-03 |
Family
ID=81777589
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210192242.XA Pending CN114580663A (zh) | 2022-03-01 | 2022-03-01 | 面向数据非独立同分布场景的联邦学习方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114580663A (zh) |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114863499A (zh) * | 2022-06-30 | 2022-08-05 | 广州脉泽科技有限公司 | 一种基于联邦学习的指静脉与掌静脉识别方法 |
CN115496204A (zh) * | 2022-10-09 | 2022-12-20 | 南京邮电大学 | 一种跨域异质场景下的面向联邦学习的评测方法及装置 |
CN115511108A (zh) * | 2022-09-27 | 2022-12-23 | 河南大学 | 一种基于数据集蒸馏的联邦学习个性化方法 |
CN115659212A (zh) * | 2022-09-27 | 2023-01-31 | 南京邮电大学 | 跨域异质场景下基于tdd通信的联邦学习效率评测方法 |
CN116204599A (zh) * | 2023-05-06 | 2023-06-02 | 成都三合力通科技有限公司 | 基于联邦学习的用户信息分析系统及方法 |
CN116541712A (zh) * | 2023-06-26 | 2023-08-04 | 杭州金智塔科技有限公司 | 基于非独立同分布数据的联邦建模方法及系统 |
WO2024022082A1 (zh) * | 2022-07-29 | 2024-02-01 | 脸萌有限公司 | 信息分类的方法、装置、设备和介质 |
CN117708681A (zh) * | 2024-02-06 | 2024-03-15 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
-
2022
- 2022-03-01 CN CN202210192242.XA patent/CN114580663A/zh active Pending
Cited By (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114863499A (zh) * | 2022-06-30 | 2022-08-05 | 广州脉泽科技有限公司 | 一种基于联邦学习的指静脉与掌静脉识别方法 |
CN114863499B (zh) * | 2022-06-30 | 2022-12-13 | 广州脉泽科技有限公司 | 一种基于联邦学习的指静脉与掌静脉识别方法 |
WO2024022082A1 (zh) * | 2022-07-29 | 2024-02-01 | 脸萌有限公司 | 信息分类的方法、装置、设备和介质 |
CN115511108A (zh) * | 2022-09-27 | 2022-12-23 | 河南大学 | 一种基于数据集蒸馏的联邦学习个性化方法 |
CN115659212A (zh) * | 2022-09-27 | 2023-01-31 | 南京邮电大学 | 跨域异质场景下基于tdd通信的联邦学习效率评测方法 |
CN115659212B (zh) * | 2022-09-27 | 2024-04-09 | 南京邮电大学 | 跨域异质场景下基于tdd通信的联邦学习效率评测方法 |
CN115496204A (zh) * | 2022-10-09 | 2022-12-20 | 南京邮电大学 | 一种跨域异质场景下的面向联邦学习的评测方法及装置 |
CN115496204B (zh) * | 2022-10-09 | 2024-02-02 | 南京邮电大学 | 一种跨域异质场景下的面向联邦学习的评测方法及装置 |
CN116204599A (zh) * | 2023-05-06 | 2023-06-02 | 成都三合力通科技有限公司 | 基于联邦学习的用户信息分析系统及方法 |
CN116204599B (zh) * | 2023-05-06 | 2023-10-20 | 成都三合力通科技有限公司 | 基于联邦学习的用户信息分析系统及方法 |
CN116541712A (zh) * | 2023-06-26 | 2023-08-04 | 杭州金智塔科技有限公司 | 基于非独立同分布数据的联邦建模方法及系统 |
CN116541712B (zh) * | 2023-06-26 | 2023-12-26 | 杭州金智塔科技有限公司 | 基于非独立同分布数据的联邦建模方法及系统 |
CN117708681A (zh) * | 2024-02-06 | 2024-03-15 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
CN117708681B (zh) * | 2024-02-06 | 2024-04-26 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114580663A (zh) | 面向数据非独立同分布场景的联邦学习方法及系统 | |
CN112364943B (zh) | 一种基于联邦学习的联邦预测方法 | |
CN110084377B (zh) | 用于构建决策树的方法和装置 | |
US20230237326A1 (en) | Data processing method and apparatus | |
WO2023071626A1 (zh) | 一种联邦学习方法、装置、设备、存储介质及产品 | |
CN113011646B (zh) | 一种数据处理方法、设备以及可读存储介质 | |
Liu et al. | Keep your data locally: Federated-learning-based data privacy preservation in edge computing | |
CN113128701A (zh) | 面向样本稀疏性的联邦学习方法及系统 | |
WO2023185539A1 (zh) | 机器学习模型训练方法、业务数据处理方法、装置及系统 | |
CN114298122B (zh) | 数据分类方法、装置、设备、存储介质及计算机程序产品 | |
CN114580662A (zh) | 基于锚点聚合的联邦学习方法和系统 | |
CN115130711A (zh) | 一种数据处理方法、装置、计算机及可读存储介质 | |
CN114282059A (zh) | 视频检索的方法、装置、设备及存储介质 | |
CN113821668A (zh) | 数据分类识别方法、装置、设备及可读存储介质 | |
CN115879542A (zh) | 一种面向非独立同分布异构数据的联邦学习方法 | |
Uddin et al. | Federated learning via disentangled information bottleneck | |
Matsuda et al. | An empirical study of personalized federated learning | |
CN110248195A (zh) | 用于输出信息的方法和装置 | |
Marnissi et al. | Client selection in federated learning based on gradients importance | |
CN116244484B (zh) | 一种面向不平衡数据的联邦跨模态检索方法及系统 | |
Chen et al. | Resource-aware knowledge distillation for federated learning | |
WO2023029944A1 (zh) | 联邦学习的方法和装置 | |
CN113962417A (zh) | 一种视频处理方法、装置、电子设备和存储介质 | |
Tang et al. | Optimizing federated learning on non-IID data using local Shapley value | |
CN115936110A (zh) | 一种缓解异构性问题的联邦学习方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |