CN113158550B - 一种联邦学习方法、装置、电子设备及存储介质 - Google Patents

一种联邦学习方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN113158550B
CN113158550B CN202110314849.6A CN202110314849A CN113158550B CN 113158550 B CN113158550 B CN 113158550B CN 202110314849 A CN202110314849 A CN 202110314849A CN 113158550 B CN113158550 B CN 113158550B
Authority
CN
China
Prior art keywords
terminal
model
accuracy
local
terminals
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110314849.6A
Other languages
English (en)
Other versions
CN113158550A (zh
Inventor
高志鹏
邱晨豪
杨杨
张瀛瀚
赵晨
莫梓嘉
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Posts and Telecommunications
Original Assignee
Beijing University of Posts and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Posts and Telecommunications filed Critical Beijing University of Posts and Telecommunications
Priority to CN202110314849.6A priority Critical patent/CN113158550B/zh
Publication of CN113158550A publication Critical patent/CN113158550A/zh
Application granted granted Critical
Publication of CN113158550B publication Critical patent/CN113158550B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F30/00Computer-aided design [CAD]
    • G06F30/20Design optimisation, verification or simulation
    • G06F30/27Design optimisation, verification or simulation using machine learning, e.g. artificial intelligence, neural networks, support vector machines [SVM] or training a model
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

本申请实施例提供的联邦学习方法、装置、电子设备及存储介质,应用于深度神经网络模型训练的技术领域,通过初始化并接收各终端发送的共享验证数据和本地模型的模型参数;根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重;根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合;将模型聚合得到的全局模型的模型参数发送至各终端,以使各终端根据接收到的全局模型的模型参数对本地模型的参数进行更新。从而通过根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重,根据计算得到的权重进行模型聚合,从而可以考虑各个终端的差异,避免由于数据分布不均,导致的计算效率下降的问题,提高计算的效率。

Description

一种联邦学习方法、装置、电子设备及存储介质
技术领域
本申请涉及深度神经网络模型训练的技术领域,特别是涉及一种联邦学习方法、装置、电子设备及存储介质。
背景技术
目前,联邦学习算法的深度神经网络模型训练领域已经应用的越来越广泛。通过联邦学习算法可以在不上传终端训练节点本地原始数据情况下,完成全局模型训练,在训练过程中,终端设备将本地更新的模型参数发送到服务器进行聚合,然后中央服务器对这些参数聚合后重新下发更新终端节点模型完成一轮次训练。
然而,当前的联邦学习算法在进行聚合过程中,往往是通过平均的方法,将终端返回参数进行平均之后得到全局模型。而在实际训练过程中,经常面临终端节点数据非均匀分布的情况,通过平均的方法往往会导致模型不易收敛或收敛时间较长,从而导致计算效率下降。
发明内容
本申请实施例的目的在于提供一种联邦学习方法、装置、电子设备及存储介质,用以解决联邦学习过程中计算效率下降的问题。具体技术方案如下:
本申请实施例的第一方面,首先提供了一种联邦学习方法,应用于服务器,所述服务器用于管理至少两个终端,所述方法包括:
初始化并接收各所述终端发送的共享验证数据和本地模型的模型参数,其中,每一终端均运行一本地模型;
根据各所述终端发送的所述共享验证数据和所述本地模型的模型参数计算各终端的权重;
根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数进行模型聚合;
将模型聚合得到的全局模型的模型参数发送至各所述终端,以使各所述终端根据接收到的所述全局模型的模型参数对所述本地模型的参数进行更新。
可选的,所述初始化并接收各终端发送的共享验证数据和本地模型的模型参数,包括:
初始化所述全局模型;
将初始化的全局模型发送至各所述终端;
接收各所述终端发送的所述本地数据并进行保存;
接收各所述终端发送的所述共享验证数据和所述本地模型的模型参数。
可选的,所述根据各所述终端发送的所述共享验证数据和所述本地模型的模型参数计算各终端的权重,包括:
根据各所述终端发送的所述共享验证数据计算各所述终端的准确率;
获取各所述终端的历史准确率,并根据各所述终端的准确率和各所述终端的历史准确率,计算各所述终端的准确率平均值;
根据各所述终端的准确率和各所述终端的准确率平均值,计算各所述终端的准确率进步值;
根据各所述终端的准确率进步值计算各所述终端的权重。
可选的,所述根据各所述终端发送的所述共享验证数据计算各所述终端的准确率,包括:
根据各所述终端发送的所述共享验证数据,通过预设公式:
Figure BDA0002990757060000021
计算各所述终端的准确率,其中,所述共享验证数据中包括测试数据xi
Figure BDA0002990757060000022
代表测试数据xi在当前模型下的预测结果,
Figure BDA0002990757060000023
代表当前模型对共享验证数据中测试数据xi的预测结果是否与对应数据的标签yi一致,一致为1,不一致取值为0,number(Dtest)为测试数据Dtest的条数,
Figure BDA0002990757060000024
为n个终端在t轮次的准确率;
所述获取各所述终端的历史准确率,并根据各所述终端的准确率和各所述终端的历史准确率,计算各所述终端的准确率平均值,包括:
通过预设公式:
Figure BDA0002990757060000031
计算各终端的历史准确率,其中,
Figure BDA0002990757060000032
为n个终端在t轮次的历史准确率,N为参与的终端的数量,m为截尾系数,
Figure BDA0002990757060000033
为第i个终端在t轮次的准确率;
根据各所述终端的准确率和各终端的历史准确率,通过预设公式:
Figure BDA0002990757060000034
计算各所述终端的准确率平均值,其中,
Figure BDA0002990757060000035
为n个终端在t轮次的准确率平均值,
Figure BDA0002990757060000036
为n个终端在t-1轮次的历史准确率;
所述根据各所述终端的准确率和各所述终端的准确率平均值,计算各所述终端的准确率进步值,包括:
根据各所述终端的准确率和各所述终端的准确率平均值,通过预设公式:
Figure BDA0002990757060000037
计算各所述终端的准确率进步值,其中,Rn为准确率进步值,σt表示终端是否参与t轮次的训练,参与为1,不参与为0,Tmax为训练总轮次数,
Figure BDA0002990757060000038
为阶跃函数,若t轮次节点参数准确率高于节点准确率平均值,则为1,否则为0;
所述根据各所述终端的准确率进步值计算各所述终端的权重,包括:
根据各所述终端的准确率进步值,通过预设公式:
Figure BDA0002990757060000041
计算各所述终端的权重,其中,γi为终端训练趋势度量系数,γi=a*b,
Figure BDA0002990757060000042
为第i个终端在t轮次相对于全局模型的准确率进步值。
可选的,所述根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数进行模型聚合,包括:
根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数,通过预设公式:
Figure BDA0002990757060000043
进行模型聚合,其中,
Figure BDA0002990757060000044
为终端i在轮次t提交的本地模型,
Figure BDA0002990757060000045
为节点i在t轮次的节点权重,
Figure BDA0002990757060000046
为节点i在轮次t提交的本地模型。
本申请实施例的第二方面,还提供了一种联邦学习装置,应用于服务器,所述服务器用于管理至少两个终端,所述装置包括:
参数接收模块,用于初始化并接收各所述终端发送的共享验证数据和本地模型的模型参数,其中,每一终端均运行一本地模型;
权重计算模块,用于根据各所述终端发送的所述共享验证数据和所述本地模型的模型参数计算各终端的权重;
模型聚合模块,用于根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数进行模型聚合;
参数更新模块,用于将模型聚合得到的全局模型的模型参数发送至各所述终端,以使各所述终端根据接收到的所述全局模型的模型参数对所述本地模型的参数进行更新。
可选的,所述参数接收模块,包括:
初始化子模块,用于初始化所述全局模型;
模型发送子模块,用于将初始化的全局模型发送至各所述终端;
数据保存子模块,用于接收各所述终端发送的所述本地数据并进行保存;
参数接收子模块,用于接收各所述终端发送的所述共享验证数据和所述本地模型的模型参数。
可选的,所述权重计算模块,包括:
准确率计算子模块,用于根据各所述终端发送的所述共享验证数据计算各所述终端的准确率;
平均值计算子模块,用于获取各所述终端的历史准确率,并根据各所述终端的准确率和各所述终端的历史准确率,计算各所述终端的准确率平均值;
进步值计算子模块,用于根据各所述终端的准确率和各所述终端的准确率平均值,计算各所述终端的准确率进步值;
权重计算子模块,用于根据权重各所述终端的准确率进步值计算各所述终端的权重。
可选的,所述准确率计算子模块,具体用于:
根据各所述终端发送的所述共享验证数据,通过预设公式:
Figure BDA0002990757060000051
计算各所述终端的准确率,其中,所述共享验证数据中包括测试数据xi
Figure BDA0002990757060000052
代表测试数据xi在当前模型下的预测结果,
Figure BDA0002990757060000053
代表当前模型对共享验证数据中测试数据xi的预测结果是否与对应数据的标签yi一致,一致为1,不一致取值为0,number(Dtest)为测试数据Dtest的条数,
Figure BDA0002990757060000054
为n个终端在t轮次的准确率;
所述平均值计算子模块,具体用于:
通过预设公式:
Figure BDA0002990757060000061
计算各终端的历史准确率,其中,
Figure BDA0002990757060000062
为n个终端在t轮次的历史准确率,N为参与的终端的数量,m为截尾系数,
Figure BDA0002990757060000063
为第i个终端在t轮次的准确率;
根据各所述终端的准确率和各终端的历史准确率,通过预设公式:
Figure BDA0002990757060000064
计算各所述终端的准确率平均值,其中,
Figure BDA0002990757060000065
为n个终端在t轮次的准确率平均值,
Figure BDA0002990757060000066
为n个终端在t-1轮次的历史准确率;
所述进步值计算子模块,具体用于:
根据各所述终端的准确率和各所述终端的准确率平均值,通过预设公式:
Figure BDA0002990757060000067
计算各所述终端的准确率进步值,其中,Rn为准确率进步值,σt表示终端是否参与t轮次的训练,参与为1,不参与为0,Tmax为训练总轮次数,
Figure BDA0002990757060000068
为阶跃函数,若t轮次节点参数准确率高于节点准确率平均值,则为1,否则为0;
所述权重计算子模块,具体用于:
根据各所述终端的准确率进步值,通过预设公式:
Figure BDA0002990757060000069
计算各所述终端的权重,其中,γi为终端训练趋势度量系数,γi=a*b,
Figure BDA00029907570600000610
为第i个终端在t轮次相对于全局模型的准确率进步值。
可选的,所述模型聚合模块,具体用于:
根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数,通过预设公式:
Figure BDA0002990757060000071
进行模型聚合,其中,
Figure BDA0002990757060000072
为终端i在轮次t提交的本地模型,
Figure BDA0002990757060000073
为节点i在t轮次的节点权重,
Figure BDA0002990757060000074
为节点i在轮次t提交的本地模型。
本申请实施例还提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现上述任一联邦学习方法。
本申请实施例还提供了一种计算机可读存储介质,计算机可读存储介质内存储有计算机程序,计算机程序被处理器执行时实现上述任一联邦学习方法。
本申请实施例还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述任一联邦学习方法。
本申请实施例有益效果:
本申请实施例提供的联邦学习方法、装置、电子设备及存储介质,通过初始化并接收各终端发送的共享验证数据和本地模型的模型参数;根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重;根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合;将模型聚合得到的全局模型的模型参数发送至各终端,以使各终端根据接收到的全局模型的模型参数对本地模型的参数进行更新。从而通过根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重,根据计算得到的权重进行模型聚合,从而可以考虑各个终端的差异,避免由于数据分布不均,导致的计算效率下降的问题,提高计算的效率。
当然,实施本申请的任一产品或方法并不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的实施例。
图1为本申请实施例提供的联邦学习方法的一种流程示意图;
图2为本申请实施例提供的全局模型初始化的一种流程示意图;
图3为本申请实施例提供的通过本地模型对全局模型进行参数更新的流程示意图;
图4为本申请实施例提供的通过全局模型对本地模型进行参数更新的流程示意图;
图5为本申请实施例提供的计算各终端的权重的一种流程示意图;
图6为本申请实施例提供的一种联邦学习系统的结构示意图;
图7为本申请实施例提供的联邦学习装置的一种结构示意图;
图8为本申请实施例提供的电子设备的一种结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员基于本申请所获得的所有其他实施例,都属于本申请保护的范围。
随着深度神经网络的模型的不断发展和使用场景对于隐私的重视程度的日益上升,联邦学习逐渐在越来越多的场景下发挥作用。联邦学习天然具有数据保存在终端本地不进行上传的优势,可以有效保护数据隐私,在训练过程中终端设备将本地更新的模型参数发送到服务器进行聚合,中央服务器对这些参数聚合后重新下发更新终端模型完成轮次训练。在传统的联邦学习模型中,通常使用联邦平均作为参数聚合方式,即将终端返回参数进行平均之后得到全局模型,并返回终端节点进行下一轮次的训练。
基于联邦平均的训练模式适用于终端节点数据均匀分布且终端算力相对平均的情况,在实际训练过程中,经常面临数据非均匀分布,终端算力不均的问题,同时也存在终端非诚实或是存在恶意会对训练过程进行干扰的情况。传统的联邦平均聚合方式在此情境下无法高效完成模型的训练,模型不易收敛或收敛时间较长。
为了解决以上问题,本申请实施例提供了一种联邦学习方法,应用于服务器,服务器用于管理至少两个终端,上述方法包括:
初始化并接收各终端发送的共享验证数据和本地模型的模型参数,其中,每一终端均运行一本地模型;
根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重;
根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合;
将模型聚合得到的全局模型的模型参数发送至各终端,以使各终端根据接收到的全局模型的模型参数对本地模型的参数进行更新。
可见,通过本申请实施例的联邦学习方法,可以通过根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重,根据计算得到的权重进行模型聚合,从而可以考虑各个终端的差异,避免由于数据分布不均,导致的计算效率下降的问题,提高计算的效率。
参见图1,图1为本申请实施例提供的联邦学习方法的一种流程示意图,上述方法应用于服务器,服务器用于管理至少两个终端,上述方法包括:
步骤S11,初始化并接收各终端发送的共享验证数据和本地模型的模型参数。
其中,每一终端均运行一本地模型。在实际使用过程中,上述终端可以是手机、物联网设备、传感器等带有较弱计算能力的智能终端,上述服务器为计算能力较强的服务器,如边缘服务器等,上述服务器负责接收终端节点传回的各个本地模型参数并进行联邦聚合和模型下发过程。上述本地模型可以是人脸识别模型等网络模型。通过各终端手机的数据,可以对本地模型进行训练。
可选的,初始化并接收各终端发送的共享验证数据和本地模型的模型参数,包括:初始化全局模型;将初始化的全局模型发送至各终端;接收各终端发送的本地数据并进行保存;接收各终端发送的共享验证数据和本地模型的模型参数。
参见图2,图2为本申请实施例提供的全局网络模型初始化的一种流程示意图,其中,中央服务器通过模型初始化得到全局模型,并将全局模型分发至终端节点,终端节点通过本地数据对本地模型进行训练,并共享验证集上传至中央服务器。
步骤S12,根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重。
其中,上述共享验证数据可以是终端采集到的数据,在初始化过程中向服务器进行发送。本地模型的模型参数可以包括本地模型的结构、所配置的变量等模型参数。
根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重可以通过计算各终端的准确率、可靠性等,然后根据准确率、可靠性等计算各个终端的权重。
步骤S13,根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合。
根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合,可以为根据步骤S12计算到的各终端的权重,以及各终端的发送的本地模型的结构、所配置的变量等参数,进行模型聚合。参见图3,本地模块通过节点参数验证模块基于验证集的预验证和历史信息记录模块的全局模型准确率Lg、节点准确率列表La以及及诶单交互频率Lf反馈至联邦聚合模块,联邦聚合模块通过加权聚合生成全局模型。
可选的,根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合,包括:
根据各终端的权重和各终端发送的本地模型的模型参数,通过预设公式:
Figure BDA0002990757060000111
进行模型聚合,其中,
Figure BDA0002990757060000112
为终端i在轮次t提交的本地模型,
Figure BDA0002990757060000113
为节点i在t轮次的节点权重,
Figure BDA0002990757060000114
为节点i在轮次t提交的本地模型。
步骤S14,将模型聚合得到的全局模型的模型参数发送至各终端,以使各终端根据接收到的全局模型的模型参数对本地模型的参数进行更新。
其中,将模型聚合得到的全局模型的模型参数发送至各终端,可以将全局模型的结构、所配置的变量等参数发送至各终端。其中,上述本地模型和全局模型可以为相同类型的模型。
例如,参见图4,终端节点通过本地数据训练本地模块,并将本地模块的参数上传至中央服务器,中央服务器通过联邦聚合得到全局模型,并根据全局模型更新本地模型。
可见,通过本申请实施例的联邦学习方法,可以通过根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重,根据计算得到的权重进行模型聚合,从而可以考虑各个终端的差异,避免由于数据分布不均,导致的计算效率下降的问题,提高计算的效率。
可选的,参见图5,步骤S12根据各终端发送的本地数据和预先存储的各终端的历史数据计算各终端的权重,包括:
步骤S121,根据各终端发送的共享验证数据计算各终端的准确率;
步骤S122,获取各终端的历史准确率,并根据各终端的准确率和各终端的历史准确率,计算各终端的准确率平均值;
步骤S123,根据各终端的准确率和各终端的准确率平均值,计算各终端的准确率进步值;
步骤S124,根据各终端的准确率进步值计算各终端的权重。
可选的,根据各终端发送的共享验证数据计算各终端的准确率,包括:
根据各终端发送的共享验证数据,通过预设公式:
Figure BDA0002990757060000121
计算各终端的准确率,其中,所述共享验证数据中包括测试数据xi
Figure BDA0002990757060000122
代表测试数据xi在当前模型下的预测结果,
Figure BDA0002990757060000123
代表当前模型对共享验证数据中测试数据xi的预测结果是否与对应数据的标签yi一致,一致为1,不一致取值为0,number(Dtest)为测试数据Dtest的条数,
Figure BDA0002990757060000124
为n个终端在t轮次的准确率;
获取各终端的历史准确率,并根据各终端的准确率和各终端的历史准确率,计算各终端的准确率平均值,包括:
通过预设公式:
Figure BDA0002990757060000125
计算各终端的历史准确率,其中,
Figure BDA0002990757060000126
为n个终端在t轮次的历史准确率,N为参与的终端的数量,m为截尾系数,
Figure BDA0002990757060000127
为第i个终端在t轮次的准确率;
根据各终端的准确率和各终端的历史准确率,通过预设公式:
Figure BDA0002990757060000128
计算各终端的准确率平均值,其中,
Figure BDA0002990757060000129
为n个终端在t轮次的准确率平均值,
Figure BDA00029907570600001210
为n个终端在t-1轮次的历史准确率;
根据各终端的准确率和各终端的准确率平均值,计算各终端的准确率进步值,包括:
根据各终端的准确率和各终端的准确率平均值,通过预设公式:
Figure BDA0002990757060000131
计算各终端的准确率进步值,其中,Rn为准确率进步值,σt表示终端是否参与t轮次的训练,参与为1,不参与为0,Tmax为训练总轮次数,
Figure BDA0002990757060000132
为阶跃函数,若t轮次节点参数准确率高于节点准确率平均值,则为1,否则为0;
根据各终端的准确率进步值计算各终端的权重,包括:
根据各终端的准确率进步值,通过预设公式:
Figure BDA0002990757060000133
计算各终端的权重,其中,γi为终端训练趋势度量系数,γi=a*b,
Figure BDA0002990757060000134
为第i个终端在t轮次相对于全局模型的准确率进步值。
参见图6,图6为本申请实施例提供的一种联邦学习系统的结构示意图,包括:中央服务器和终端训练节点。
终端训练节点,带有较弱计算能力的智能终端,可以为手机、物联网设备、传感器等。
中央服务器,为计算能力较强的服务器(如边缘服务器),负责接收终端节点传回的各个本地模型参数并进行联邦聚合和模型下发过程。
中央服务器包括验证模块,验证模块会在联邦学习初始化期间,收集各参与节点提交的公共验证数据集。联邦学习训练期间,终端训练节点使用本地数据进行模型更新并将更新后的模型参数上传到中央服务器的节点参数验证模块。验证模块使用验证数据集对节点的参数进行验证,确定节点参数可靠性。
验证模块进行验证后将信息提交给节点记录模块,节点行为记录模块负责记录联邦学习训练过程中节点参数历史准确率和全局模型准确率以及节点参与训练记录,中央服务器根据当前轮次的评估结果和节点历史数据确定节点参数权重,并进行参数聚合工作。
本申请实施例的第二方面,还提供了一种联邦学习装置,应用于服务器,服务器用于管理至少两个终端,参见图7,图7为本申请实施例提供的联邦学习装置的一种结构示意图,上述装置包括:
参数接收模块701,用于初始化并接收各终端发送的共享验证数据和本地模型的模型参数,其中,每一终端均运行一本地模型;
权重计算模块702,用于根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重;
模型聚合模块703,用于根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合;
参数更新模块704,用于将模型聚合得到的全局模型的模型参数发送至各终端,以使各终端根据接收到的全局模型的模型参数对本地模型的参数进行更新。
可选的,参数接收模块701,包括:
初始化子模块,用于初始化全局模型;
模型发送子模块,用于将初始化的全局模型发送至各终端;
数据保存子模块,用于接收各终端发送的本地数据并进行保存;
参数接收子模块,用于接收各终端发送的共享验证数据和本地模型的模型参数。
可选的,权重计算模块702,包括:
准确率计算子模块,用于根据各终端发送的共享验证数据计算各终端的准确率;
平均值计算子模块,用于获取各终端的历史准确率,并根据各终端的准确率和各终端的历史准确率,计算各终端的准确率平均值;
进步值计算子模块,用于根据各终端的准确率和各终端的准确率平均值,计算各终端的准确率进步值;
权重计算子模块,用于根据权重各终端的准确率进步值计算各终端的权重。
可选的,准确率计算子模块,具体用于:
根据各终端发送的共享验证数据,通过预设公式:
Figure BDA0002990757060000151
计算各终端的准确率,其中,所述共享验证数据中包括测试数据xi
Figure BDA0002990757060000152
代表测试数据xi在当前模型下的预测结果,
Figure BDA0002990757060000153
代表当前模型对共享验证数据中测试数据xi的预测结果是否与对应数据的标签yi一致,一致为1,不一致取值为0,number(Dtest)为测试数据Dtest的条数,
Figure BDA0002990757060000154
为n个终端在t轮次的准确率;
平均值计算子模块,具体用于:
通过预设公式:
Figure BDA0002990757060000155
计算各终端的历史准确率,其中,
Figure BDA0002990757060000156
为n个终端在t轮次的历史准确率,N为参与的终端的数量,m为截尾系数,
Figure BDA0002990757060000157
为第i个终端在t轮次的准确率;
根据各终端的准确率和各终端的历史准确率,通过预设公式:
Figure BDA0002990757060000158
计算各终端的准确率平均值,其中,
Figure BDA0002990757060000159
为n个终端在t轮次的准确率平均值,
Figure BDA00029907570600001510
为n个终端在t-1轮次的历史准确率;
进步值计算子模块,具体用于:
根据各终端的准确率和各终端的准确率平均值,通过预设公式:
Figure BDA0002990757060000161
计算各终端的准确率进步值,其中,Rn为准确率进步值,σt表示终端是否参与t轮次的训练,参与为1,不参与为0,Tmax为训练总轮次数,
Figure BDA0002990757060000162
为阶跃函数,若t轮次节点参数准确率高于节点准确率平均值,则为1,否则为0;
权重计算子模块,具体用于:
根据各终端的准确率进步值,通过预设公式:
Figure BDA0002990757060000163
计算各终端的权重,其中,γi为终端训练趋势度量系数,γi=a*b,
Figure BDA0002990757060000164
为第i个终端在t轮次相对于全局模型的准确率进步值。
可选的,模型聚合模块703,具体用于:
根据各终端的权重和各终端发送的本地模型的模型参数,通过预设公式:
Figure BDA0002990757060000165
进行模型聚合,其中,
Figure BDA0002990757060000166
为终端i在轮次t提交的本地模型,
Figure BDA0002990757060000167
为节点i在t轮次的节点权重,
Figure BDA0002990757060000168
为节点i在轮次t提交的本地模型。
可见,通过本申请实施例的联邦学习方法,可以通过根据各终端发送的共享验证数据和本地模型的模型参数计算各终端的权重,根据计算得到的权重进行模型聚合,从而可以考虑各个终端的差异,避免由于数据分布不均,导致的计算效率下降的问题,提高计算的效率。
本申请实施例还提供了一种电子设备,如图8所示,包括处理器801、通信接口802、存储器803和通信总线804,其中,处理器801,通信接口802,存储器803通过通信总线804完成相互间的通信,
存储器803,用于存放计算机程序;
处理器801,用于执行存储器803上所存放的程序时,实现如下步骤:
接收各终端发送的本地数据和本地模型的模型参数,其中,每一终端均运行一本地模型,通过本地数据对本地模型进行训练;
根据各终端发送的本地数据和预先存储的各终端的历史数据计算各终端的权重;
根据各终端的权重和各终端发送的本地模型的模型参数进行模型聚合;
将模型聚合得到的全局模型的模型参数发送至各终端,以使各终端根据接收到的全局模型的模型参数对本地模型的参数进行更新。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本申请提供的又一实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现上述任一联邦学习方法的步骤。
在本申请提供的又一实施例中,还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述实施例中任一联邦学习方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk(SSD))等。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置、电子设备、存储介质、计算机程序产品实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本申请的较佳实施例,并非用于限定本申请的保护范围。凡在本申请的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本申请的保护范围内。

Claims (7)

1.一种联邦学习方法,其特征在于,应用于服务器,所述服务器用于管理至少两个终端,所述方法包括:
初始化并接收各所述终端发送的共享验证数据和本地模型的模型参数,其中,每一终端均运行一本地模型;
根据各所述终端发送的所述共享验证数据,通过预设公式:
Figure FDA0003732073550000011
计算各所述终端的准确率,其中,所述共享验证数据中包括测试数据xi
Figure FDA0003732073550000019
代表测试数据xi在当前模型下的预测结果,
Figure FDA00037320735500000110
代表当前模型对共享验证数据中测试数据xi的预测结果是否与对应数据的标签yi一致,一致为1,不一致取值为0,number(Dtest)为测试数据Dtest的条数,
Figure FDA0003732073550000012
为n个终端在t轮次的准确率;
通过预设公式:
Figure FDA0003732073550000013
计算各终端的历史准确率,其中,
Figure FDA0003732073550000014
为n个终端在t轮次的历史准确率,N为参与的终端的数量,m为截尾系数,
Figure FDA0003732073550000015
为第i个终端在t轮次的准确率;
根据各所述终端的准确率和各终端的历史准确率,通过预设公式:
Figure FDA0003732073550000016
计算各所述终端的准确率平均值,其中,
Figure FDA0003732073550000017
为n个终端在t轮次的准确率平均值,
Figure FDA0003732073550000018
为n个终端在t-1轮次的历史准确率;
根据各所述终端的准确率和各所述终端的准确率平均值,通过预设公式:
Figure FDA0003732073550000021
计算各所述终端的准确率进步值,其中,Rn为准确率进步值,σt表示终端是否参与t轮次的训练,参与为1,不参与为0,Tmax为训练总轮次数,
Figure FDA0003732073550000022
Figure FDA0003732073550000023
为阶跃函数,若t轮次节点参数准确率高于节点准确率平均值,则为1,否则为0;
根据各所述终端的准确率进步值,通过预设公式:
Figure FDA0003732073550000024
计算各所述终端的权重,其中,γi为终端训练趋势度量系数,
Figure FDA0003732073550000025
为第i个终端在t轮次相对于全局模型的准确率进步值;
根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数进行模型聚合;
将模型聚合得到的全局模型的模型参数发送至各所述终端,以使各所述终端根据接收到的所述全局模型的模型参数对所述本地模型的参数进行更新。
2.根据权利要求1所述的方法,其特征在于,所述初始化并接收各终端发送的共享验证数据和本地模型的模型参数,包括:
初始化所述全局模型;
将初始化的全局模型发送至各所述终端;
接收各所述终端发送的本地数据并进行保存;
接收各所述终端发送的所述共享验证数据和所述本地模型的模型参数。
3.根据权利要求1所述的方法,其特征在于,所述根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数进行模型聚合,包括:
根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数,通过预设公式:
Figure FDA0003732073550000031
进行模型聚合,其中,
Figure FDA0003732073550000032
为终端i在轮次t提交的本地模型,
Figure FDA0003732073550000033
为节点i在t轮次的节点权重,
Figure FDA0003732073550000034
为节点i在轮次t提交的本地模型。
4.一种联邦学习装置,其特征在于,应用于服务器,所述服务器用于管理至少两个终端,所述装置包括:
参数接收模块,用于初始化并接收各所述终端发送的共享验证数据和本地模型的模型参数,其中,每一终端均运行一本地模型;
权重计算模块,包括:准确率计算子模块、平均值计算子模块、进步值计算子模块、权重计算子模块;
所述准确率计算子模块,用于:
根据各所述终端发送的所述共享验证数据,通过预设公式:
Figure FDA0003732073550000035
计算各所述终端的准确率,其中,所述共享验证数据中包括测试数据xi
Figure FDA0003732073550000036
代表测试数据xi在当前模型下的预测结果,
Figure FDA0003732073550000037
代表当前模型对共享验证数据中测试数据xi的预测结果是否与对应数据的标签yi一致,一致为1,不一致取值为0,number(Dtest)为测试数据Dtest的条数,
Figure FDA0003732073550000038
为n个终端在t轮次的准确率;
所述平均值计算子模块,用于:
通过预设公式:
Figure FDA0003732073550000041
计算各终端的历史准确率,其中,
Figure FDA0003732073550000042
为n个终端在t轮次的历史准确率,N为参与的终端的数量,m为截尾系数,
Figure FDA0003732073550000043
为第i个终端在t轮次的准确率;
根据各所述终端的准确率和各终端的历史准确率,通过预设公式:
Figure FDA0003732073550000044
计算各所述终端的准确率平均值,其中,
Figure FDA0003732073550000045
为n个终端在t轮次的准确率平均值,
Figure FDA0003732073550000046
为n个终端在t-1轮次的历史准确率;
所述进步值计算子模块,用于:
根据各所述终端的准确率和各所述终端的准确率平均值,通过预设公式:
Figure FDA0003732073550000047
计算各所述终端的准确率进步值,其中,Rn为准确率进步值,σt表示终端是否参与t轮次的训练,参与为1,不参与为0,Tmax为训练总轮次数,
Figure FDA0003732073550000048
Figure FDA0003732073550000049
为阶跃函数,若t轮次节点参数准确率高于节点准确率平均值,则为1,否则为0;
所述权重计算子模块,用于:
根据各所述终端的准确率进步值,通过预设公式:
Figure FDA00037320735500000410
计算各所述终端的权重,其中,γi为终端训练趋势度量系数,
Figure FDA00037320735500000411
为第i个终端在t轮次相对于全局模型的准确率进步值;
模型聚合模块,用于根据所述各终端的权重和各所述终端发送的所述本地模型的模型参数进行模型聚合;
参数更新模块,用于将模型聚合得到的全局模型的模型参数发送至各所述终端,以使各所述终端根据接收到的所述全局模型的模型参数对所述本地模型的参数进行更新。
5.根据权利要求4所述的装置,其特征在于,所述参数接收模块,包括:
初始化子模块,用于初始化所述全局模型;
模型发送子模块,用于将初始化的全局模型发送至各所述终端;
数据保存子模块,用于接收各所述终端发送的本地数据并进行保存;
参数接收子模块,用于接收各所述终端发送的所述共享验证数据和所述本地模型的模型参数。
6.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1-3任一所述的方法步骤。
7.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-3任一所述的方法步骤。
CN202110314849.6A 2021-03-24 2021-03-24 一种联邦学习方法、装置、电子设备及存储介质 Active CN113158550B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110314849.6A CN113158550B (zh) 2021-03-24 2021-03-24 一种联邦学习方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110314849.6A CN113158550B (zh) 2021-03-24 2021-03-24 一种联邦学习方法、装置、电子设备及存储介质

Publications (2)

Publication Number Publication Date
CN113158550A CN113158550A (zh) 2021-07-23
CN113158550B true CN113158550B (zh) 2022-08-26

Family

ID=76884594

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110314849.6A Active CN113158550B (zh) 2021-03-24 2021-03-24 一种联邦学习方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN113158550B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114800545B (zh) * 2022-01-18 2023-10-27 泉州华中科技大学智能制造研究院 一种基于联邦学习的机器人控制方法
WO2024026583A1 (zh) * 2022-07-30 2024-02-08 华为技术有限公司 一种通信方法和通信装置

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112182595B (zh) * 2019-07-03 2024-03-26 北京百度网讯科技有限公司 基于联邦学习的模型训练方法及装置
CN110442457A (zh) * 2019-08-12 2019-11-12 北京大学深圳研究生院 基于联邦学习的模型训练方法、装置及服务器
CN111814985B (zh) * 2020-06-30 2023-08-29 平安科技(深圳)有限公司 联邦学习网络下的模型训练方法及其相关设备

Also Published As

Publication number Publication date
CN113158550A (zh) 2021-07-23

Similar Documents

Publication Publication Date Title
CN113158550B (zh) 一种联邦学习方法、装置、电子设备及存储介质
CN113282960B (zh) 一种基于联邦学习的隐私计算方法、装置、系统及设备
WO2018130201A1 (zh) 确定关联账号的方法、服务器及存储介质
WO2022088541A1 (zh) 一种基于差分进化的联邦学习激励方法和系统
CN108965951B (zh) 广告的播放方法及装置
CN112884016B (zh) 云平台可信评估模型训练方法和云平台可信评估方法
CN111652371A (zh) 一种离线强化学习网络训练方法、装置、系统及存储介质
WO2023000491A1 (zh) 一种应用推荐方法、装置、设备及计算机可读存储介质
CN116627970A (zh) 一种基于区块链和联邦学习的数据共享方法及装置
CN114116705A (zh) 联合学习中确定参与方贡献值的方法及装置
CN114116707A (zh) 确定联合学习中参与方贡献度的方法及装置
CN114357526A (zh) 抵御推断攻击的医疗诊断模型差分隐私联合训练方法
CN108805332B (zh) 一种特征评估方法和装置
CN111510473B (zh) 访问请求处理方法、装置、电子设备和计算机可读介质
CN110349571B (zh) 一种基于连接时序分类的训练方法及相关装置
CN113378994A (zh) 一种图像识别方法、装置、设备及计算机可读存储介质
CN108768743A (zh) 一种用户识别方法、装置及服务器
CN116362894A (zh) 多目标学习方法、装置、电子设备及计算机可读存储介质
CN114116740A (zh) 用于联合学习中确定参与方贡献度的方法及装置
CN111585739B (zh) 一种相位调整方法及装置
CN111416595B (zh) 一种基于多核融合的大数据滤波方法
CN114553869A (zh) 基于联合学习的确定资源贡献度的方法、装置和电子设备
CN108880935B (zh) 网络节点重要度的获得方法和装置、设备、存储介质
CN109255099B (zh) 一种计算机可读存储介质、数据处理方法、装置及服务器
CN112751924B (zh) 一种数据推送方法、系统及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant