CN116776155B - 一种基于联邦学习的模型训练方法、装置、设备和介质 - Google Patents
一种基于联邦学习的模型训练方法、装置、设备和介质 Download PDFInfo
- Publication number
- CN116776155B CN116776155B CN202310870506.7A CN202310870506A CN116776155B CN 116776155 B CN116776155 B CN 116776155B CN 202310870506 A CN202310870506 A CN 202310870506A CN 116776155 B CN116776155 B CN 116776155B
- Authority
- CN
- China
- Prior art keywords
- model
- gradient
- processed
- target
- determining
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 184
- 238000000034 method Methods 0.000 title claims abstract description 54
- 238000012216 screening Methods 0.000 claims abstract description 111
- 230000002776 aggregation Effects 0.000 claims description 23
- 238000004220 aggregation Methods 0.000 claims description 23
- 238000012545 processing Methods 0.000 claims description 19
- 230000004931 aggregating effect Effects 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 4
- 238000012163 sequencing technique Methods 0.000 claims description 2
- 230000000694 effects Effects 0.000 abstract description 22
- 239000013598 vector Substances 0.000 description 14
- 230000008569 process Effects 0.000 description 11
- 238000004458 analytical method Methods 0.000 description 9
- 230000006870 function Effects 0.000 description 9
- 230000007246 mechanism Effects 0.000 description 9
- 230000004044 response Effects 0.000 description 9
- 230000003287 optical effect Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 238000005516 engineering process Methods 0.000 description 4
- 238000005457 optimization Methods 0.000 description 4
- 238000003860 storage Methods 0.000 description 4
- 238000009826 distribution Methods 0.000 description 3
- 238000012935 Averaging Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000012804 iterative process Methods 0.000 description 2
- 230000002093 peripheral effect Effects 0.000 description 2
- 238000006116 polymerization reaction Methods 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 101100481876 Danio rerio pbk gene Proteins 0.000 description 1
- 101100391182 Dictyostelium discoideum forI gene Proteins 0.000 description 1
- 101100481878 Mus musculus Pbk gene Proteins 0.000 description 1
- 230000002411 adverse Effects 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000002301 combined effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 230000005251 gamma ray Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 230000000379 polymerizing effect Effects 0.000 description 1
- 230000008707 rearrangement Effects 0.000 description 1
- 230000002040 relaxant effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/602—Providing cryptographic facilities or services
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Bioethics (AREA)
- Evolutionary Computation (AREA)
- Computer Hardware Design (AREA)
- Medical Informatics (AREA)
- Computer Security & Cryptography (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Databases & Information Systems (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Complex Calculations (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明实施例公开了一种基于联邦学习的模型训练方法、装置、设备和介质,方法包括:获取预设服务端发送的当前全局模型,并基于预设训练样本对当前全局模型训练的训练结果确定待处理模型梯度集合;基于待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据预设参数估计值确定目标梯度筛选阈值;根据目标梯度筛选阈值对待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。本发明实施例的技术方案可以动态确定梯度筛选阈值,提高梯度筛选效果和模型训练性能。
Description
技术领域
本发明实施例涉及机器学习技术领域,尤其涉及一种基于联邦学习的模型训练方法、装置、设备和介质。
背景技术
由于基于联邦学习的模型训练方法存在客户隐私泄露的问题,因此现有技术通过设置梯度阈值来对重要的梯度进行筛选,再对筛选出的梯度进行加密,进而解决客户隐私泄露的问题。但是,现有基于联邦学习的模型训练方法往往根据经验来设置每一轮用于更新模型的梯度数量(K),并将其作为算法的超参数,即在训练过程中不发生变化。然而,在模型训练过程中,不同时期的梯度绝对值会存在较大差异,使用同一K值对梯度进行筛选,不能达到最优的选择效果和训练性能。
发明内容
本发明实施例提供了一种基于联邦学习的模型训练方法、装置、设备和介质,可以动态确定梯度筛选阈值,提高梯度筛选效果和模型训练性能。
第一方面,本发明实施例提供了一种基于联邦学习的模型训练方法,该方法包括:
获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;
基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;
根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;
将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型。
第二方面,本发明实施例提供了一种基于联邦学习的模型训练装置,该装置包括:
待处理模型梯度集合确定模块,用于获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;
目标梯度筛选阈值确定模块,用于基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;
模型梯度处理模块,用于根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;
目标模型梯度集合发送模块,用于将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型。
第三方面,本发明实施例提供了一种计算机设备,该计算机设备包括:
一个或多个处理器;
存储器,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现任一实施例所述的一种基于联邦学习的模型训练方法。
第四方面,本发明实施例提供了一种计算机可读介质,其上存储有计算机程序,该程序被处理器执行时实现任一实施例所述的一种基于联邦学习的模型训练方法。
本发明实施例所提供的技术方案,通过获取预设服务端发送的当前全局模型,并基于预设训练样本对当前全局模型训练的训练结果确定待处理模型梯度集合;基于待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据预设参数估计值确定目标梯度筛选阈值;根据目标梯度筛选阈值对待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。本发明实施例的技术方案解决了现有基于联邦学习的模型训练技术中梯度筛选阈值固定不变,导致梯度筛选效果和模型训练性能不足的问题,可以动态确定梯度筛选阈值,提高梯度筛选效果和模型训练性能。
附图说明
图1是本发明实施例提供的一种基于联邦学习的模型训练方法流程图;
图2是本发明实施例提供的又一种基于联邦学习的模型训练方法流程图;
图3是本发明实施例提供的一种基于联邦学习的模型训练装置的结构示意图;
图4是本发明实施例提供的一种计算机设备的结构示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
图1是本发明实施例提供的一种基于联邦学习的模型训练方法流程图,本发明实施例可适用于基于联邦学习进行模型训练的场景中,该方法可以由基于联邦学习的模型训练装置执行,该装置可以由软件和/或硬件的方式来实现。
如图1所示,基于联邦学习的模型训练方法包括以下步骤:
S110、获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合。
其中,预设训练样本可以是预设的用于训练全局模型的样本。具体的,预设训练样本可以是预设客户端上的一些本地数据集合。当前全局模型可以是当前正在进行训练的全局模型。在全局模型的训练过程中会对其进行多次迭代更新,在每一次迭代过程中,可以将在当前正在迭代更新的全局模型作为当前全局模型。进一步的,当前全局模型可以由预设服务端发送至预设客户端,预设客户端接收到当前全局模型后,可以根据预设训练样本对当前全局模型进行训练,得到对应的模型梯度集合。
待处理模型梯度集合可以是需要进行后续处理并用于更新当前全局模型的模型梯度集合。待处理模型梯度集合可以由预设训练样本对当前全局模型训练的训练结果进行确定。具体的,可以基于该训练结果和和当前模型梯度集合对应的上一模型误差梯度集合确定待处理模型梯度集合。
其中,上一模型误差梯度集合可以是在上一次全局模型迭代更新过程中,没有被筛选中的梯度集合。具体的,可以对两个集合中相同维度的模型梯度进行加合,并将加和后的模型梯度集合作为待处理模型梯度集合。
S120、基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值。
其中,预设参数估计值可以是预设的关于当前全局模型和预设训练样本的参数估计值。预设参数估计值可以作为确定梯度筛选阈值过程中的一个中间参数,用于确定目标梯度筛选阈值。
目标梯度筛选阈值可以是用于对待处理模型梯度进行筛选的参考阈值。具体的,可以根据待处理模型梯度集合中各待处理模型梯度确定预设参数估计值,再根据预设参数估计值确定目标梯度筛选阈值。本发明实施例的技术方案可以基于当前全局模型训练的训练结果确定待处理模型梯度集合,再根据待处理模型梯度集合中各待处理模型梯度确定预设参数估计值,再进一步根据预设参数估计值确定目标梯度筛选阈值,进而动态确定每一次全局模型迭代更新中的梯度筛选阈值,提高梯度筛选效果和模型训练性能。
由于训练过程中大部分的模型梯度的数值很接近于0,丢弃这些模型梯度并不会对模型训练造成显著的影响,而且选出更少的模型梯度也意味着会更好程度的保护样本数据的信息安全性。但是,由于后续需要对筛选出的模型梯度进行加密处理,这会影响后续训练全局模型的精度,因此,需要寻找出一个可以较好保护样本数据的信息安全且对全局模型训练精度影响较小的最优值,基于该最优值对待处理模型梯度进行筛选。
S130、根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合。
其中,加密处理可以是一种对筛选出来的待处理模型梯度进行加密的数据处理方式。具体的,可以采用增加噪声的方式,对筛选出来的待处理模型梯度进行加密。目标模型梯度集合可以是用于更新当前全局模型的梯度的集合。具体的,可以根据目标梯度筛选阈值对待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合。
其中,在根据目标梯度筛选阈值对待处理模型梯度进行筛选的过程中,可以将目标梯度筛选阈值与待处理模型梯度进行对比,并根据对比结果确定筛选出来的待处理模型梯度。
S140、将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型
其中,预设服务端可以是预设的用于更新全局模型的服务端。在获取到目标模型梯度集合后,将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。目标全局模型可以是经过更新后最终确定的全局模型。具体的,在对当前全局模型进行更新时,可以利用目标模型梯度集合中的模型梯度使当前全局模型沿着梯度下降方向进行更新,实现对当前全局模型的更新。
本发明实施例所提供的技术方案,通过获取预设服务端发送的当前全局模型,并基于预设训练样本对当前全局模型训练的训练结果确定待处理模型梯度集合;基于待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据预设参数估计值确定目标梯度筛选阈值;根据目标梯度筛选阈值对待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。本发明实施例的技术方案解决了现有基于联邦学习的模型训练技术中梯度筛选阈值固定不变,导致梯度筛选效果和模型训练性能不足的问题,可以动态确定梯度筛选阈值,提高梯度筛选效果和模型训练性能。
图2是本发明实施例提供的又一种基于联邦学习的模型训练方法流程图,本发明实施例可适用于基于联邦学习进行模型训练的场景中,本实施例在上述实施例的基础上,进一步的说明如何基于预设训练样本对当前全局模型训练的训练结果确定待处理模型梯度集合;如何基于待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据预设参数估计值确定目标梯度筛选阈值;以及如何根据目标梯度筛选阈值对待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合。该装置可以由软件和/或硬件的方式来实现,集成于具有应用开发功能的计算机设备中。
为了方便对本发明实施例进行理解,下面对本发明实施例涉及到的背景和原理进行说明。
在本发明实施例中,我们考虑了一个通用的联邦学习(Federated Learning,FL)系统,该系统包含一个参数服务器(Parameter Server,PS)和M个客户端,客户端表示为[M]={1,2,...,M}。每个客户端m都有一个大小为Dm的独立数据集,表示为这些客户端共同合作训练一个全局模型,其训练目标为最小化全局损失函数/>表示为其中,w是模型的训练参数,/>是客户端m的本地损失函数,即每个参与训练的数据样本的损失函数的平均值,表示为其中/>是客户端m在t轮全局迭代中本地选中的训练样本集合。
在本发明实施例中,我们利用差分隐私(Differentially Private,DP)机制来保护客户端梯度的隐私,并且从理论上提供了严格的隐私保证。因此,我们首先介绍差分隐私最常用的定义:
定义1.((∈,δ)-差分隐私)。假设和/>是一对相邻数据集,即/>和/>最多有一个样本不相同,表示为/>随机算法/>满足(∈,δ)-差分隐私当且仅当对于任意的/>和/>以及任意的输出/>有:
其中,表示算法/>的所有可能输出,(∈,δ)表示隐私预算,用于衡量隐私的泄露程度。从该定义中可以看出,隐私预算越小意味着算法/>在相邻数据集上得到相同输出的概率越接近,也就意味着隐私保护效果越好。如果δ=0,则可以表示为∈-差分隐私。
在本发明实施例中,我们使用流行的拉普拉斯机制(Laplace Mechanism,LM)来实现差分隐私,拉普拉斯机制通过产生拉普拉斯随机噪声来扰乱梯度,从而使得攻击者无法利用加噪梯度进行样本攻击。我们以数据集上的查询任务为例子对拉普拉斯机制进行描述。假设表示一个查询结果,其中w表示查询的输入,/>表示查询的数据集。因此,我们可以定义查询结果的l1-敏感度为/>其中和/>是一对相邻数据集,则满足∈-差分隐私的拉普拉斯机制有如下定理:
定理1.(拉普拉斯机制)。给定数据集和查询输入w,满足∈-差分隐私的拉普拉斯机制用拉普拉斯噪声Z扰乱查询结果如下:/>其中Z满足
为了简化表示,噪声Z的分布可以表示为定理1表明了拉普拉斯机制保护一个查询结果消耗的隐私预算为∈。我们可以将FL中客户端的每个梯度当作一个查询结果,拉普拉斯机制为梯度添加噪声以实现差分隐私。在FL过程中,每个客户端参与多轮全局迭代,并且每轮迭代需要与服务器交互多个梯度,因此,我们使用以下定理对消耗的隐私预算进行累加计算:
定理2.(叠加定理)。假设满足∈i-差分隐私。如果算法定义为/>则/>满足/>-差分隐私。
定理2说明了消耗的隐私预算和保护的查询结果数量成正比,当总隐私预算固定时,查询结果数量越多,则每个查询结果分配的隐私预算越小,即DP噪声的方差越大。为了减少噪声的影响,应该适当减少查询结果的数量。为此,稀疏向量技术(Sparse VectorTechnique,SVT)被提出用于只选择部分绝对值较大的查询结果进行响应。假设θ表示阈值,当查询结果qi满足
|qi|+v≥θ+ρ (2)
时,查询结果会被响应。ρ和v都表示拉普拉斯随机噪声,他们被用来保护阈值θ的隐私,分布分别为和/>其中c表示响应的总查询数量。从公式(2)中可以看出,当|qi|的值越大,则其满足响应条件的概率越大。满足响应条件的查询结果还需要添加拉普拉斯噪声/>再进行响应,即返回qi+Z。已经有研究工作证明了:响应c个查询结果的稀疏向量技术满足(∈1+∈2+∈3)-差分隐私。
基于这些背景知识,接下来我们提出稀疏响应的DPFL(Differentially PrivateFederated Learning,差分隐私联邦学习)框架,该框架假设每一轮全局迭代所有的客户端都参与训练。假设wt表示第t轮全局迭代聚合后的全局模型参数,wt[j]表示向量wt中的第j维;表示客户端m在第t轮全局迭代训练得到的本地模型参数;表示客户端m在第t轮全局迭代训练得到的梯度,/>表示向量中的第j维,η表示学习率;/>表示客户端m的梯度的l1-敏感度,其中表示梯度的上界;/>表示客户端m在t轮迭代实际上传的向量,/>表示向量/>中的第j维;∈m表示客户端m保护一个梯度所消耗的隐私预算,根据稀疏向量技术我们将∈m划分为∈m,1,∈m,2和∈m,3,分别用于生成拉普拉斯随机噪声/> 和c表示每个客户端限定上传的总梯度数量;/>表示误差存储向量,用于存储并累积没被选中的梯度,/>表示向量/>中的第j维;Im用于记录客户端m已经上传的梯度数量。有了上述定义,我们可以对稀疏响应的DPFL框架进行描述,其通常包含多轮迭代。
如图2所示,基于联邦学习的模型训练方法包括以下步骤:
S210、根据预设训练样本对当前全局模型进行训练,得到当前模型梯度集合。
预设训练样本可以是预设的用于训练全局模型的样本,具体的,预设训练样本可以是预设客户端上的一些本地数据集合。当前全局模型可以是当前正在进行训练的全局模型。在全局模型的训练过程中会对其进行多次迭代更新,在每一次迭代过程中,可以将在当前正在迭代更新的全局模型作为当前全局模型。进一步的,当前全局模型可以由预设服务端发送至预设客户端,预设客户端接收到当前全局模型后,可以根据预设训练样本对当前全局模型进行训练,得到当前模型梯度集合。其中,当前模型梯度集合可以是对当前全局模型进行训练得到的梯度的集合。
S220、根据所述当前模型梯度集合和所述当前模型梯度集合对应的上一模型误差梯度集合确定待处理模型梯度集合。
其中,待处理模型梯度集合可以是需要进行后续处理并用于更新当前全局模型的模型梯度集合。上一模型误差梯度集合可以是在上一次全局模型迭代更新过程中,没有被筛选中的梯度集合。待处理模型梯度集合可以由当前模型梯度集合和上一模型误差梯度集合进行确定。具体的,可以对两个集合中相同维度的模型梯度进行加合,并将加和后的模型梯度集合作为待处理模型梯度集合。
示例性的,在每一个客户端m利用本地数据集和全局模型wt-1进行本地训练,得到梯度/>加上之前累计的误差,则可以将梯度更新为/> 可以表示待处理模型梯度集合。
S230、根据所述当前全局模型和所述预设训练样本确定训练样本梯度。
其中,训练样本梯度可以是预设训练样本相对于全局模型的梯度。具体的,可以根据当前全局模型和预设训练样本确定训练样本梯度。
示例性的,在第t轮全局迭代中,客户端m会利用从服务器下载的最新全局模型wt-1和本地数据集中的每一个样本ξ计算对应的梯度/>
S240、聚合各所述训练样本梯度得到训练样本梯度聚合参数,并根据所述训练样本梯度聚合参数确定全局估计值。
其中,训练样本梯度聚合参数可以是对训练样本梯度进行聚合得到的参数。训练样本梯度集合参数可以作为确定梯度筛选阈值过程中的一个中间参数。全局估计值可以是对当前全局模型和预设训练样本的相关参数进行估计的数值。全局估计值可以由当前全局模型和预设训练样本确定。
示例性的,聚合所有样本的梯度,表示为其中表示/>中的样本数量。基于上述信息,客户端可以估计其本地参数,分别表示为:1)2)/>3)随后,平均所有客户端的估计值以得到全局估计值,如,其中Dm表示客户端m的样本数量并且/> 表示全局估计值。
S250、聚合各所述待处理模型梯度得到模型梯度聚合参数,并根据所述模型梯度聚合参数和各所述待处理模型梯度的差值确定差值平方上界值。
其中,模型梯度聚合参数可以是对待处理模型梯度集合中各待处理模型梯度进行聚合得到的参数。模型梯度聚合参数可以通过聚合各待处理模型梯度得到。训练样本梯度集合参数可以作为确定梯度筛选阈值过程中的一个中间参数。
差值平方上界值可以是预设的用于确定模型筛选阈值的一个中间参数,该数值可以体现本地梯度和全局梯度的差值平方的上界。具体的,可以根据模型梯度聚合参数和各待处理模型梯度的差值确定差值平方上界值。
示例性,在第t轮全局迭代中,客户端m上传向量给服务器,服务器聚合这些向量得到全局梯度/>并利用/>和at的差异来估计Γ,即其中,Γ表示差值平方上界值。
S260、根据所述全局估计值和所述差值平方上界值确定目标梯度筛选阈值。
其中,由于训练过程中大部分的模型梯度的数值很接近于0,丢弃这些模型梯度并不会对模型训练造成显著的影响,而且选出更少的模型梯度也意味着会更好程度的保护样本数据的信息安全性。但是,由于后续需要对筛选出的模型梯度进行加密处理,这会影响后续训练全局模型的精度,因此,需要寻找出一个可以较好保护样本数据的信息安全且对全局模型训练精度影响较小的最优值,基于该最优值对待处理模型梯度进行筛选。
目标梯度筛选阈值可以是用于对待处理模型梯度进行筛选的参考阈值。具体的,可以根据全局估计值和差值平方上界值确定目标梯度筛选阈值。本发明实施例的技术方案可以确定全局估计值和差值平方上界值,再进一步确定目标梯度筛选阈值,进而动态确定每一次全局模型迭代更新中的梯度筛选阈值,提高梯度筛选效果和模型训练性能。
示例性的,可以根据全局估计值和差值平方上界值确定K值,再对降序排序,并将绝对值第K大的值设置为阈值θ。其中,θ也即目标梯度筛选阈值。
进一步的,确定K值的过程如下所示:
在上述的稀疏响应DPFL框架中,我们特别关注K值的设置,因为K值既决定了每一轮迭代客户端选择的梯度数量,又决定了每一轮迭代客户端的噪声方差。我们需要通过最优化K值,即最优化梯度数量和噪声方差对算法性能的共同影响,来实现最优的稀疏响应。
参考相关的理论分析工作,我们的分析同样对训练模型做了一些常规的假设。我们假设所有客户端的损失函数都是L平滑的,随机梯度和真实梯度的方差上界表示为σ2,随机梯度的方差上界表示为G2。由于联邦学习中客户端的数据存在异构性,因此我们用本地梯度和全局梯度的差值平方上界Γ来量化客户端之间的异构程度,表示为我们的分析基于现有的联邦学习分析工作,但是,稀疏响应DPFL框架多了两个额外部分的误差,分别是梯度选择误差和梯度扰乱误差。
首先,我们分析隐私地选择梯度带来的误差。在该部分,我们用gm来表示客户端m的任意一个梯度向量。根据上述算法描述,只有满足选择条件|gm[j]|+vm≥θm+ρm的梯度会被选中,其中θm是gm中第K大的梯度绝对值。然而,由于vm和ρm都是拉普拉斯随机噪声,我们很难分析每一个梯度gm[j]满足选择条件的概率。因此,我们的分析不单独考虑每个梯度满足选择条件的概率,而从整体上考虑梯度满足选择条件的期望概率,表示为pm,即
pm=Pr[|gm[j]|+vm≥θm+ρm]。 (4)
我们基于拉普拉斯随机噪声ρm和vm的分布,即和推导了pm的上下界,表示为γm≤pm≤ωm,其中d是训练模型的维度。λm的值依赖于∈m,1和∈m,2的大小关系,当2∈m,1≠∈m,2,/>当2∈m,1=∈m,2,/>
假设集合表示客户端m在gm上选中的梯度集合,由于ρm和vm的随机性,/>是一个随机集合。虽然我们无法估计/>中的元素,但是根据pm的定义,我们可以得到/>中元素数量的期望值,表示为/>因此,我们可以得到客户端m的压缩率表示为将压缩率pm代入TopK算法的分析中并利用pm≥γm,我们可以得到客户端m隐私地选择梯度产生的误差上界为/>
接下来,我们分析扰乱选中的梯度带来的误差。在本专利中,我们使用随机噪声来扰乱客户端m选中的每个梯度,根据拉普拉斯分布的性质我们可得Zm的方差为/>根据上述分析,我们可以得到客户端选中的梯度数量的均值为利用噪声的叠加性以及pm≤ωm,我们可以得到客户端m扰乱选中梯度集合/>产生的误差上界,即噪声方差上界,表示为/>
将上述推导得到的梯度选择误差和梯度扰乱误差代入到联邦学习的理论分析框架中,并且将学习率设置为其中T是总迭代轮数,并且满足/>则可以得到稀疏响应的DPFL框架进行了T轮迭代后的收敛结果为:
其中,xT是一个从历史T轮全局模型中随机采样的模型,每个历史模型的采样概率都为/>
从收敛结果我们可以得出以下三个结论:
由于稀疏响应的DPFL框架在随机拉普拉斯噪声ρm和vm的扰乱下概率地选择梯度,因此收敛结果中的梯度选择误差与∈m,1和∈m,2相关(通过γm)。由于该框架用随机拉普拉斯噪声Zm来扰乱选中的梯度,因此收敛结果中的梯度扰乱误差与∈m,1和∈m,2(通过ωm)以及∈m,3相关。
由于λm<1,因此随着K的增加,γm和ωm都会变大。γm变大意味着梯度选择误差会变小,因为有更多的梯度被选择;ωm变大意味着梯度扰乱误差会变大,因为大量的隐私预算会被用于保护不太重要的梯度,这会加剧DP噪声对模型训练的不利影响。
当迭代轮数T趋于无穷时,收敛结果中所有的项都会趋于0除了梯度扰乱误差项。该结果符合先前的DP理论分析工作,因为DP噪声的影响无法被消除。然而,我们可以通过设置较小的K值来减小DP噪声的影响。
上述收敛结果进一步验证了我们的推测,K值的设置应该权衡梯度选择和梯度扰乱两者对模型的共同影响。基于该收敛结果,我们可以通过最小化收敛结果的上界来求解最优K值。然而,我们还需要知道总的全局迭代轮数,也就是所有客户端的隐私都消耗完所需的全局迭代轮数。由于每个客户端最多选择c个梯度,并且客户端m在每一轮全局迭代中选择的梯度数量的期望值为 因此客户端m参与的迭代轮数的期望值下界为/>因此,我们使用该下界作为客户端m参与的全局迭代轮数,即/>因为不同客户端的ωm不同,因此我们使用的ωm平均值来估计整个系统的全局迭代轮数,即其中/>
将代入到收敛结果里,可以得到优化问题为:
其中,B=-32δ2L2G2。将/> 和/>代入到/>中,则优化问题的目标函数转化为:
其中,通过将K的约束从整数放松为实数,我们可以证明/>在取值范围为1≤K≤d之间是一个凸函数。因此,我们可以高效且快速地求解公式(7)的优化问题,从而得到K的最优值。
然而,上述优化问题的求解还需要计算参数L,G2,σ2和Γ的估计值以代入公式(7)中。基于先前的相关工作,我们设计了完整的参数估计方案以便在模型的训练过程中不断计算和更新这些参数的估计值。其中参数L,G2,σ2在客户端本地计算估计值,在服务器聚合得到全局估计值,而Γ则利用客户端上传的实际向量进行估计。
1)估计L,G2,σ2:
在第t轮全局迭代中,客户端m会利用从服务器下载的最新全局模型wt-1和本地数据集中的每一个样本ξ计算对应的梯度/>并聚合所有样本的梯度,表示为其中/>表示/>中的样本数量。基于上述信息,客户端可以估计其本地参数,分别表示为:1)
2)/>3)并将这些估计值上传到服务器。
服务器负责平均所有客户端的估计值以得到全局估计值,如, 其中Dm表示客户端m的样本数量并且
2)估计Γ:
在第t轮全局迭代中,客户端m上传向量给服务器,服务器聚合这些向量得到全局梯度/>并利用/>和at的差异来估计Γ,即/>
基于求解到L,G2,σ2和Γ的估计值,我们可以直接带入公式(7)求解最优K值。
我们的实验证明了,尽管对于极其复杂的卷积神经网络模型,依然只需要10-20轮迭代就可以获得稳定的最优K值,这证明了本发明实施例的稳定性和高效性。
S270、将所述待处理模型梯度与所述目标梯度筛选阈值进行对比,并根据对比结果对所述待处理模型梯度进行筛选。
其中,可以将待处理模型梯度与目标梯度筛选阈值进行对比,并根据对比结果对待处理模型梯度进行筛选,以选取出较为重要的梯度。示例性的,可以对每一个梯度都进行判断,只有通过选择条件且满足Im<c的梯度会被选中。其中,Im表示客户端m当前已经上传的梯度数量,c表示预设的客户端上传的梯度总数。
S280、对筛选出来的各待处理模型梯度添加噪声,并根据各添加过噪声的待处理模型梯度得到所述目标模型梯度集合。
其中,目标模型梯度集合可以是用于更新当前全局模型的梯度的集合。具体的,可以对筛选出来的各待处理模型梯度添加噪声,并根据各添加过噪声的待处理模型梯度得到目标模型梯度集合。通过对筛选出来的待处理模型梯度添加噪声,并将添加过噪声的待处理模型梯度组合成目标梯度集合,可以对模型梯度中的隐私信息进行加密,提高隐私信息的安全性。
示例性的,选中的梯度会以的形式上传给服务器,并且更新Im=Im+1以及/>/>
S290、将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型。
其中,预设服务端可以是预设的用于更新全局模型的服务端。在获取到目标模型梯度集合后,将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。目标全局模型可以是经过更新后最终确定的全局模型。具体的,在对当前全局模型进行更新时,利用目标模型梯度集合中的模型梯度使当前全局模型沿着梯度下降方向进行更新,实现对当前全局模型的更新。
示例性,服务器收集所有客户端上传的训练结果,表示为随后服务器对/>中满足/>的维度进行聚合并更新全局模型如下所示:
在一种可选的实施方式中,基于没有被筛选中的待处理模型梯度确定辅助模型梯度集合,可以将辅助模型梯度集合发送至预设服务端,以使预设服务端基于辅助模型梯度集合确定当前全局模型的待更新梯度。
其中,辅助模型梯度集合可以是用于辅助预设服务端去确定待更新梯度的集合。具体的,可以基于没有被筛选中的待处理模型梯度组成辅助模型梯度集合。由于全局模型的梯度是多维的,根据辅助模型梯度集合可以确定哪些维度的模型梯度不用进行更新,进而确定需要进行更新的模型梯度,这些需要进行更新的模型梯度也即待更新梯度。示例性的,不被选中的梯度会以的形式上传给服务器,其中⊥表示一个指示符,并且更新/>
在一种可选的实施方式中,在预设服务端基于目标模型梯度集合对当前全局模型进行更新之后,还包括:获取更新后的全局模型;并基于预设训练样本对更新后的全局模型进行迭代训练。在全局模型的训练过程,为了提高全局模型的训练效果,可以对全局模型进行多轮迭代更新。具体的,可以设置对全局模型进行迭代更新的轮数,将确定出目标模型梯度集合,并将其发送至预设服务端作为一次迭代更新的结束标志,随后也会根据客户端发送的目标模型梯度集合中梯度的数量来更新客户端当前已经上传的梯度数量。在完成一次迭代更新后,可将记录的客户端当前已经上传的梯度数量和预设上传的梯度总数进行对比,当记录的已上传的梯度数量小于预设的梯度总数,可以向预设服务端发送指令,以使预设服务端发送更新后的全局模型,并将获取到的更新后的全局模型作为当前全局模型再次开始迭代更新,直至所有客户端记录的已上传的梯度数量均大于等于预设的梯度总数。
本发明实施例所提供的技术方案,通过根据预设训练样本对当前全局模型进行训练,得到当前模型梯度集合;根据当前模型梯度集合和当前模型梯度集合对应的上一模型误差梯度集合确定待处理模型梯度集合;根据当前全局模型和预设训练样本确定训练样本梯度;聚合各训练样本梯度得到训练样本梯度聚合参数,并根据训练样本梯度聚合参数确定全局估计值;聚合各待处理模型梯度得到模型梯度聚合参数,并根据模型梯度聚合参数和各待处理模型梯度的差值确定差值平方上界值;根据全局估计值和差值平方上界值确定目标梯度筛选阈值;将待处理模型梯度与目标梯度筛选阈值进行对比,并根据对比结果对待处理模型梯度进行筛选;对筛选出来的各待处理模型梯度添加噪声,并根据各添加过噪声的待处理模型梯度得到目标模型梯度集合;将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。本发明实施例的技术方案解决了现有基于联邦学习的模型训练技术中梯度筛选阈值固定不变,导致梯度筛选效果和模型训练性能不足的问题,可以动态确定梯度筛选阈值,提高梯度筛选效果和模型训练性能。
图3是本发明实施例提供的一种基于联邦学习的模型训练装置的结构示意图,本发明实施例可适用于基于联邦学习进行模型训练的场景中,该装置可以由软件和/或硬件的方式来实现,集成于具有应用开发功能的计算机设备中。
如图3所示,基于联邦学习的模型训练装置包括:待处理模型梯度集合确定模块310、目标梯度筛选阈值确定模块320、模型梯度处理模块330和目标模型梯度集合发送模块340。
其中,待处理模型梯度集合确定模块310,用于获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;目标梯度筛选阈值确定模块320,用于基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;模型梯度处理模块330,用于根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;目标模型梯度集合发送模块340,用于将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型。
本发明实施例所提供的技术方案,通过获取预设服务端发送的当前全局模型,并基于预设训练样本对当前全局模型训练的训练结果确定待处理模型梯度集合;基于待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据预设参数估计值确定目标梯度筛选阈值;根据目标梯度筛选阈值对待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;将目标模型梯度集合发送至预设服务端,以使预设服务端基于目标模型梯度集合对当前全局模型进行更新,得到目标全局模型。本发明实施例的技术方案解决了现有基于联邦学习的模型训练技术中梯度筛选阈值固定不变,导致梯度筛选效果和模型训练性能不足的问题,可以动态确定梯度筛选阈值,提高梯度筛选效果和模型训练性能。
在一种可选的实施方式中,所述目标梯度筛选阈值确定模块320具体用于:根据所述当前全局模型和所述预设训练样本确定训练样本梯度;聚合各所述训练样本梯度得到训练样本梯度聚合参数,并根据所述训练样本梯度聚合参数确定全局估计值;聚合各所述待处理模型梯度得到模型梯度聚合参数,并根据所述模型梯度聚合参数和各所述待处理模型梯度的差值确定差值平方上界值;根据所述全局估计值和所述差值平方上界值确定所述目标梯度筛选阈值。
在一种可选的实施方式中,所述目标梯度筛选阈值确定模块320具体用于:根据所述全局估计值和所述差值平方上界值确定目标筛选梯度排名值;将各所述待处理模型梯度进行排序,并根据所述目标筛选梯度排名值从排序后的待处理模型梯度中确定所述目标梯度筛选阈值。
在一种可选的实施方式中,所述待处理模型梯度集合确定模块310具体用于:根据所述预设训练样本对所述当前全局模型进行训练,得到当前模型梯度集合;根据所述当前模型梯度集合和所述当前模型梯度集合对应的上一模型误差梯度集合确定所述待处理模型梯度集合。
在一种可选的实施方式中,所述模型梯度处理模块330具体用于:将所述待处理模型梯度与所述目标梯度筛选阈值进行对比,并根据对比结果对所述待处理模型梯度进行筛选;对筛选出来的各待处理模型梯度添加噪声,并根据各添加过噪声的待处理模型梯度得到所述目标模型梯度集合。
在一种可选的实施方式中,所述基于联邦学习的模型训练装置还包括:辅助模型梯度集合发送模块,用于:基于没有被筛选中的待处理模型梯度确定辅助模型梯度集合;将所述辅助模型梯度集合发送至所述预设服务端,以使所述预设服务端基于所述辅助模型梯度集合确定所述当前全局模型的待更新梯度。
在一种可选的实施方式中,所述基于联邦学习的模型训练装置还包括:全局模型迭代训练模块,用于:获取更新后的全局模型;并基于所述预设训练样本对更新后的全局模型进行迭代训练。
本发明实施例所提供的基于联邦学习的模型训练装置可执行本发明任意实施例所提供的基于联邦学习的模型训练方法,具备执行方法相应的功能模块和有益效果。
图4为本发明实施例提供的一种计算机设备的结构示意图。图4示出了适于用来实现本发明实施方式的示例性计算机设备12的框图。图4显示的计算机设备12仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。计算机设备12可以是任意具有计算能力的终端设备,可以配置于基于联邦学习的模型训练设备中。
如图4所示,计算机设备12以通用计算设备的形式表现。计算机设备12的组件可以包括但不限于:一个或者多个处理器或者处理单元16,系统存储器28,连接不同系统组件(包括系统存储器28和处理单元16)的总线18。
总线18可以是几类总线结构中的一种或多种,包括存储器总线或者存储器控制器,外围总线,图形加速端口,处理器或者使用多种总线结构中的任意总线结构的局域总线。举例来说,这些体系结构包括但不限于工业标准体系结构(ISA)总线,微通道体系结构(MAC)总线,增强型ISA总线、视频电子标准协会(VESA)局域总线以及外围组件互连(PCI)总线。
计算机设备12典型地包括多种计算机系统可读介质。这些介质可以是任何能够被计算机设备12访问的可用介质,包括易失性和非易失性介质,可移动的和不可移动的介质。
系统存储器28可以包括易失性存储器形式的计算机系统可读介质,例如随机存取存储器(RAM)30和/或高速缓存32。计算机设备12可以进一步包括其它可移动/不可移动的、易失性/非易失性计算机系统介质。仅作为举例,存储系统34可以用于读写不可移动的、非易失性磁介质(图4未显示,通常称为“硬盘驱动器”)。尽管图4中未示出,可以提供用于对可移动非易失性磁盘(例如“软盘”)读写的磁盘驱动器,以及对可移动非易失性光盘(例如CD-ROM,DVD-ROM或者其它光介质)读写的光盘驱动器。在这些情况下,每个驱动器可以通过一个或者多个数据介质接口与总线18相连。系统存储器28可以包括至少一个程序产品,该程序产品具有一组(例如至少一个)程序模块,这些程序模块被配置以执行本发明各实施例的功能。
具有一组(至少一个)程序模块42的程序/实用工具40,可以存储在例如系统存储器28中,这样的程序模块42包括但不限于操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。程序模块42通常执行本发明所描述的实施例中的功能和/或方法。
计算机设备12也可以与一个或多个外部设备14(例如键盘、指向设备、显示器24等)通信,还可与一个或者多个使得用户能与该计算机设备12交互的设备通信,和/或与使得该计算机设备12能与一个或多个其它计算设备进行通信的任何设备(例如网卡,调制解调器等等)通信。这种通信可以通过输入/输出(I/O)接口22进行。并且,计算机设备12还可以通过网络适配器20与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图所示,网络适配器20通过总线18与计算机设备12的其它模块通信。应当明白,尽管图4中未示出,可以结合计算机设备12使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统等。
处理单元16通过运行存储在系统存储器28中的程序,从而执行各种功能应用以及数据处理,例如实现本发实施例所提供的一种基于联邦学习的模型训练方法,该方法包括:
获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;
基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;
根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;
将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型。
本实施例提供了一种计算机可读介质,其上存储有计算机程序,该程序被处理器执行时实现如本发明任意实施例所提供的一种基于联邦学习的模型训练方法,包括:
获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;
基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;
根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;
将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型。
本发明实施例的计算机介质,可以采用一个或多个计算机可读的介质的任意组合。计算机可读介质可以是计算机可读信号介质或者计算机可读介质。计算机可读介质例如可以是但不限于:电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本文件中,计算机可读介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
计算机可读的信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。计算机可读的信号介质还可以是计算机可读介质以外的任何计算机可读介质,该计算机可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。
计算机可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于:无线、电线、光缆、RF等等,或者上述的任意合适的组合。
可以以一种或多种程序设计语言或其组合来编写用于执行本发明操作的计算机程序代码,程序设计语言包括面向对象的程序设计语言,诸如Java、Smalltalk、C++,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络,包括局域网(LAN)或广域网(WAN),连接到用户计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。
本领域普通技术人员应该明白,上述的本发明的各模块或各步骤可以用通用的计算装置来实现,它们可以集中在单个计算装置上,或者分布在多个计算装置所组成的网络上,可选地,他们可以用计算机装置可执行的程序代码来实现,从而可以将它们存储在存储装置中由计算装置来执行,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。这样,本发明不限制于任何特定的硬件和软件的结合。
注意,上述仅为本发明的较佳实施例及所运用技术原理。本领域技术人员会理解,本发明不限于这里的特定实施例,对本领域技术人员来说能够进行各种明显的变化、重新调整和替代而不会脱离本发明的保护范围。因此,虽然通过以上实施例对本发明进行了较为详细的说明,但是本发明不仅仅限于以上实施例,在不脱离本发明构思的情况下,还可以包括更多其他等效实施例,而本发明的范围由所附的权利要求范围决定。
Claims (9)
1.一种基于联邦学习的模型训练方法,其特征在于,包括:
获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;
基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;
根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;
将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型;
所述基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值,包括:
根据所述当前全局模型和所述预设训练样本确定训练样本梯度;
聚合各所述训练样本梯度得到训练样本梯度聚合参数,并根据所述训练样本梯度聚合参数确定全局估计值;
聚合各所述待处理模型梯度得到模型梯度聚合参数,并根据所述模型梯度聚合参数和各所述待处理模型梯度的差值确定差值平方上界值;
根据所述全局估计值和所述差值平方上界值确定所述目标梯度筛选阈值。
2.根据权利要求1所述的方法,其特征在于,所述根据所述全局估计值和所述差值平方上界值确定所述目标梯度筛选阈值,包括:
根据所述全局估计值和所述差值平方上界值确定目标筛选梯度排名值;
将各所述待处理模型梯度进行排序,并根据所述目标筛选梯度排名值从排序后的待处理模型梯度中确定所述目标梯度筛选阈值。
3.根据权利要求1所述的方法,其特征在于,所述基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合,包括:
根据所述预设训练样本对所述当前全局模型进行训练,得到当前模型梯度集合;
根据所述当前模型梯度集合和所述当前模型梯度集合对应的上一模型误差梯度集合确定所述待处理模型梯度集合。
4.根据权利要求1所述的方法,其特征在于,所述根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合,包括:
将所述待处理模型梯度与所述目标梯度筛选阈值进行对比,并根据对比结果对所述待处理模型梯度进行筛选;
对筛选出来的各待处理模型梯度添加噪声,并根据各添加过噪声的待处理模型梯度得到所述目标模型梯度集合。
5.根据权利要求1所述的方法,其特征在于,所述方法还包括:
基于没有被筛选中的待处理模型梯度确定辅助模型梯度集合;
将所述辅助模型梯度集合发送至所述预设服务端,以使所述预设服务端基于所述辅助模型梯度集合确定所述当前全局模型的待更新梯度。
6.根据权利要求1所述的方法,其特征在于,在所述预设服务端基于所述目标模型梯度集合对当前全局模型进行更新之后,还包括:
获取更新后的全局模型;
并基于所述预设训练样本对更新后的全局模型进行迭代训练。
7.一种基于联邦学习的模型训练装置,其特征在于,所述装置包括:
待处理模型梯度集合确定模块,用于获取预设服务端发送的当前全局模型,并基于预设训练样本对所述当前全局模型训练的训练结果确定待处理模型梯度集合;
目标梯度筛选阈值确定模块,用于基于所述待处理模型梯度集合中的各待处理模型梯度确定预设参数估计值,并根据所述预设参数估计值确定目标梯度筛选阈值;
模型梯度处理模块,用于根据所述目标梯度筛选阈值对所述待处理模型梯度进行筛选,并对筛选出来的各待处理模型梯度进行加密处理,得到目标模型梯度集合;
目标模型梯度集合发送模块,用于将所述目标模型梯度集合发送至预设服务端,以使所述预设服务端基于所述目标模型梯度集合对所述当前全局模型进行更新,得到目标全局模型;
所述目标梯度筛选阈值确定模块具体用于:根据所述当前全局模型和所述预设训练样本确定训练样本梯度;聚合各所述训练样本梯度得到训练样本梯度聚合参数,并根据所述训练样本梯度聚合参数确定全局估计值;聚合各所述待处理模型梯度得到模型梯度聚合参数,并根据所述模型梯度聚合参数和各所述待处理模型梯度的差值确定差值平方上界值;根据所述全局估计值和所述差值平方上界值确定所述目标梯度筛选阈值。
8.一种计算机设备,其特征在于,所述计算机设备包括:
一个或多个处理器;
存储器,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如权利要求1-6中任一所述的一种基于联邦学习的模型训练方法。
9.一种计算机可读介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1-6中任一所述的一种基于联邦学习的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310870506.7A CN116776155B (zh) | 2023-07-14 | 2023-07-14 | 一种基于联邦学习的模型训练方法、装置、设备和介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310870506.7A CN116776155B (zh) | 2023-07-14 | 2023-07-14 | 一种基于联邦学习的模型训练方法、装置、设备和介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116776155A CN116776155A (zh) | 2023-09-19 |
CN116776155B true CN116776155B (zh) | 2024-03-29 |
Family
ID=88013389
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310870506.7A Active CN116776155B (zh) | 2023-07-14 | 2023-07-14 | 一种基于联邦学习的模型训练方法、装置、设备和介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116776155B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117349672A (zh) * | 2023-10-31 | 2024-01-05 | 深圳大学 | 基于差分隐私联邦学习的模型训练方法、装置及设备 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113591145A (zh) * | 2021-07-28 | 2021-11-02 | 西安电子科技大学 | 基于差分隐私和量化的联邦学习全局模型训练方法 |
CN115527061A (zh) * | 2022-09-07 | 2022-12-27 | 飞马智科信息技术股份有限公司 | 一种基于联邦学习的差分隐私图像分类方法及装置 |
CN115983409A (zh) * | 2022-11-11 | 2023-04-18 | 北京大学 | 基于差分隐私的联邦学习训练方法、装置、系统及设备 |
CN116167084A (zh) * | 2023-02-24 | 2023-05-26 | 北京工业大学 | 一种基于混合策略的联邦学习模型训练隐私保护方法及系统 |
CN116303002A (zh) * | 2023-03-01 | 2023-06-23 | 哈尔滨理工大学 | 基于top-k的通信高效联邦学习的异构软件缺陷预测算法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11941520B2 (en) * | 2020-01-09 | 2024-03-26 | International Business Machines Corporation | Hyperparameter determination for a differentially private federated learning process |
-
2023
- 2023-07-14 CN CN202310870506.7A patent/CN116776155B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113591145A (zh) * | 2021-07-28 | 2021-11-02 | 西安电子科技大学 | 基于差分隐私和量化的联邦学习全局模型训练方法 |
CN115527061A (zh) * | 2022-09-07 | 2022-12-27 | 飞马智科信息技术股份有限公司 | 一种基于联邦学习的差分隐私图像分类方法及装置 |
CN115983409A (zh) * | 2022-11-11 | 2023-04-18 | 北京大学 | 基于差分隐私的联邦学习训练方法、装置、系统及设备 |
CN116167084A (zh) * | 2023-02-24 | 2023-05-26 | 北京工业大学 | 一种基于混合策略的联邦学习模型训练隐私保护方法及系统 |
CN116303002A (zh) * | 2023-03-01 | 2023-06-23 | 哈尔滨理工大学 | 基于top-k的通信高效联邦学习的异构软件缺陷预测算法 |
Non-Patent Citations (3)
Title |
---|
Boosting Accuracy of Differentially Private Federated Learning in Industrial IoT With Sparse Responses;Laizhong Cui et al.;IEEE Transactions on Industrial Informatics;第19卷(第1期);第910-920页 * |
基于秘密分享和梯度选择的高效安全联邦学习;董业;侯炜;陈小军;曾帅;;计算机研究与发展(10);全文 * |
针对分布式联邦深度学习的攻击模型及隐私对策研究;毛耀如;中国优秀硕士学位论文全文数据库信息科技辑;第2021卷(第05期);第I138-114页 * |
Also Published As
Publication number | Publication date |
---|---|
CN116776155A (zh) | 2023-09-19 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11017322B1 (en) | Method and system for federated learning | |
US20200401946A1 (en) | Management and Evaluation of Machine-Learned Models Based on Locally Logged Data | |
CN113127931B (zh) | 基于瑞丽散度进行噪声添加的联邦学习差分隐私保护方法 | |
US20190318268A1 (en) | Distributed machine learning at edge nodes | |
Polzehl et al. | Propagation-separation approach for local likelihood estimation | |
US10110449B2 (en) | Method and system for temporal sampling in evolving network | |
WO2019184640A1 (zh) | 一种指标确定方法及其相关设备 | |
CN113850272A (zh) | 基于本地差分隐私的联邦学习图像分类方法 | |
CN116776155B (zh) | 一种基于联邦学习的模型训练方法、装置、设备和介质 | |
CN114065863A (zh) | 联邦学习的方法、装置、系统、电子设备及存储介质 | |
CN115085196A (zh) | 电力负荷预测值确定方法、装置、设备和计算机可读介质 | |
US11563654B2 (en) | Detection device and detection method | |
Dong et al. | PADP-FedMeta: A personalized and adaptive differentially private federated meta learning mechanism for AIoT | |
CN115296984A (zh) | 异常网络节点的检测方法及装置、设备、存储介质 | |
CN115879152A (zh) | 基于最小均方误差准则的自适应隐私保护方法、装置及系统 | |
Tu et al. | Byzantine-robust distributed sparse learning for M-estimation | |
Sun et al. | Understanding generalization of federated learning via stability: Heterogeneity matters | |
CN113098624B (zh) | 量子态测量方法、装置、设备、存储介质及系统 | |
CN114185860A (zh) | 抗合谋攻击的数据共享方法、装置和电子设备 | |
CN115965093A (zh) | 模型训练方法、装置、存储介质及电子设备 | |
CN117057445A (zh) | 基于联邦学习框架的模型优化方法、系统和装置 | |
CN114118381B (zh) | 基于自适应聚合稀疏通信的学习方法、装置、设备及介质 | |
CN113094751B (zh) | 一种个性化隐私数据处理方法、装置、介质及计算机设备 | |
US20220188682A1 (en) | Readout-error mitigation for quantum expectation | |
US10832466B2 (en) | View-dependent stochastic volume rendering with Monte Carlo ray exploration |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |