CN111695696A - 一种基于联邦学习的模型训练的方法及装置 - Google Patents

一种基于联邦学习的模型训练的方法及装置 Download PDF

Info

Publication number
CN111695696A
CN111695696A CN202010534434.5A CN202010534434A CN111695696A CN 111695696 A CN111695696 A CN 111695696A CN 202010534434 A CN202010534434 A CN 202010534434A CN 111695696 A CN111695696 A CN 111695696A
Authority
CN
China
Prior art keywords
parameter
matrix
gradient
update
bias
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010534434.5A
Other languages
English (en)
Inventor
李晓丽
车春江
李煜政
陈川
郑子彬
严强
李辉忠
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sun Yat Sen University
WeBank Co Ltd
Original Assignee
Sun Yat Sen University
WeBank Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Sun Yat Sen University, WeBank Co Ltd filed Critical Sun Yat Sen University
Priority to CN202010534434.5A priority Critical patent/CN111695696A/zh
Publication of CN111695696A publication Critical patent/CN111695696A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Medical Informatics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于联邦学习的模型训练的方法及装置,包括:终端获取中央服务器第k次迭代的第一矩阵参数和第二矩阵参数,其中,第一矩阵参数和第二矩阵参数是中央服务器对全局模型参数矩阵进行分解得到的,k为自然数,以此减少终端的模型的参数,降低终端的模型训练时运行所需的内存消耗,然后终端使用训练样本进行训练,确定出第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,通过将第一矩阵参数的更新梯度和偏置参数的更新梯度发送至中央服务器和/或将第二矩阵参数的更新梯度和偏置参数的更新梯度发送至所述中央服务器,减少计算数据,以使中央服务器更新全局模型参数矩阵。

Description

一种基于联邦学习的模型训练的方法及装置
技术领域
本发明涉及金融科技(Fintech)领域,尤其涉及一种基于联邦学习的模型训练的方法及装置。
背景技术
随着计算机技术的发展,越来越多的技术应用在金融领域,传统金融业正在逐步向金融科技转变,但由于金融行业的安全性、实时性要求,也对技术提出的更高的要求。在金融领域,对联邦学习模型进行训练是一个重要的问题。
终端,如手机、平板电脑、可穿戴设备、区块链节点、和自动交通工具等,变得越来越受欢迎。这些设备每天都射产生大数量的有价值的数据,通过这些数据训练得到的模型可以极大提高用户的体验,例如,经过训练得到的语音模型可以提高语音识别和文字输入的性能,图像模型可以提高选择图片的能力,若将待训练数据发送至中央服务器,则可以训练出相对应的模型,但这些数据通常是受到保护的,并不能随意使用。
在现有技术中,根据联邦学习,将终端与其他终端协作或将区块链某一节点与其他节点协作,根据终端中的本地数据(或区块链节点的本地数据)在本地训练出模型,使得终端或区块链(联盟链)节点可以不用上传在本地的数据,如图1所示,图1为一种联邦学习示意图,通过联邦学习建立模型,在该模型中,每一轮通信中,终端或区块链(联盟链)节点会连接到中央服务器,并下载一个全局模型,然后根据本地数据对全局模型进行训练,将训练完成后得到的更新的梯度发送至中央服务器,以使中央服务器对全局模型进行更新。进而实现终端或区块链(联盟链)节点拥有一个全局模型,并且可以根据本地的识别任务中。
然而现有技术中的联邦学习难以部署到终端或区块链(联盟链)节点,终端或区块链(联盟链)节点的硬件平台无法跟上深度神经网络的指数级增长,因为终端或区块链(联盟链)节点的资源有限,难以部署复杂的神经网络,且模型运行所需的内存消耗大,导致模型的效率低。
发明内容
本发明实施例提供一种基于联邦学习的模型训练的方法及装置,用于减少终端或区块链(联盟链)节点的模型的参数,在不影响终端或区块链(联盟链)节点的模型训练的效率下,降低终端或区块链(联盟链)节点的模型训练时运行所需的内存消耗。
第一方面,本发明实施例提供一种运用于终端或区块链节点的联邦学习的模型训练的方法,包括:
终端获取中央服务器第k次迭代的第一矩阵参数和第二矩阵参数;所述第一矩阵参数和所述第二矩阵参数是所述中央服务器对全局模型参数矩阵进行分解得到的;所述k为自然数;
所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度;
所述终端将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器和/或将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器,以使所述中央服务器更新所述全局模型参数矩阵。
上述技术方案中,通过终端获取中央服务器分解的第k次迭代的第一矩阵参数和第二矩阵参数,以使终端的模型减少参数,降低终端的模型训练时的内存消耗,通过终端使用训练样本得到第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,用于发送给中央服务器,使中央服务器更新全局模型参数矩阵,以此实现在降低终端的模型训练时运行所需的内存消耗情况下,不影响终端的模型训练的效率,通过将第一矩阵参数的更新梯度和偏置参数的更新梯度发送至中央服务器和/或将第二矩阵参数的更新梯度和偏置参数的更新梯度发送至所述中央服务器,减少计算数据,节省了终端的模型训练时的计算时间。
可选的,所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度,包括:
所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述目标函数中所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度;
所述终端根据所述第一矩阵参数、所述第二矩阵参数和所述偏置参数、所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,确定出第一矩阵更新参数、第二矩阵更新参数和偏置更新参数;
所述终端将所述第一矩阵更新参数、所述第二矩阵更新参数和所述偏置更新参数进行反向传播计算,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度。
上述技术方案中,终端通过使用训练样本对第一矩阵参数和第二矩阵参数进行训练,进而得到第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,降低了终端的模型训练时的内存消耗。
可选的,所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,包括:
根据第一矩阵参数和所述第二矩阵参数创建转换函数,使用所述训练样本进行前向传播训练,确定出所述转换函数在前向传播中所有全连接层的损失函数;所述损失函数包括所述偏置参数;
确定出所述转换函数在前向传播中所有全连接层的损失函数的最小值,将所述最小值的损失函数作为目标函数;
根据逐元素乘法计算所述目标函数,确定出所述目标函数的误差函数;
根据随机梯度下降法计算所述误差函数,确定出所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度。
上述技术方案中,根据确定出的目标函数得到较小误差最快方向的第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,以降低终端模型训练的内存消耗。
可选的,根据下述公式(1)确定出第一矩阵更新参数;根据下述公式(2)确定出第二矩阵更新参数;根据下述公式(3)确定出其他更新参数;
Figure BDA0002536538820000041
Figure BDA0002536538820000042
Figure BDA0002536538820000043
其中,
Figure BDA0002536538820000044
为第k+1次迭代的第l层的第一矩阵更新参数,
Figure BDA0002536538820000045
为第k次迭代的第l层的第一矩阵参数,αk+1为第k+1次迭代的学习率,
Figure BDA0002536538820000046
为所述第一矩阵参数的梯度,
Figure BDA0002536538820000047
为第k+1次迭代的第l层的第二矩阵更新参数,
Figure BDA0002536538820000048
为第k次迭代的第l层的第二矩阵参数,
Figure BDA0002536538820000049
为所述第二矩阵参数的梯度,
Figure BDA00025365388200000410
为第k+1次迭代的第l层的偏置更新参数,
Figure BDA00025365388200000411
为第k次迭代的第l层的偏置参数,
Figure BDA00025365388200000412
为所述偏置参数的更新梯度,其中l为正整数。
可选的,所述终端的数量为多个;
所述方法还包括:
当所述终端的数量为偶数时,随机一半数量的终端将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;另一半数量的终端将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;当所述终端的数量为奇数时,每个终端将所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;以使所述中央服务器根据多个所述终端发送的所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度更新所述全局模型参数矩阵。
第二方面,本发明实施例提供一种运用于终端或区块链节点的联邦学习的模型训练的方法,包括:
中央服务器获取全局模型参数矩阵;
所述中央服务器将所述全局模型参数矩阵分解为第一矩阵参数和第二矩阵参数;
所述中央服务器将所述第一矩阵参数和所述第二矩阵参数发送至多个终端;以使所述多个终端对所述第一矩阵参数和所述第二矩阵参数进行训练。
中央服务器通过将全局模型参数矩阵分解为第一矩阵参数和第二矩阵参数,并发送至多个终端,以使多个终端对第一矩阵参数和第二矩阵参数进行训练,实现多个终端的模型减少参数,降低多个终端的模型训练时的内存消耗。
可选的,所述中央服务器将所述第一矩阵参数和所述第二矩阵参数发送至多个终端之后,还包括:
所述中央服务器获取所述多个终端发送的第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度;
所述中央服务器根据多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度,更新所述全局模型参数矩阵。
可选的,所述中央服务器根据多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度,更新所述全局模型参数矩阵,包括:
所述中央服务器将多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度进行联合平均计算,得到所述第一矩阵参数的平均梯度、所述第二矩阵参数的平均梯度和所述偏置参数的平均梯度;
所述中央服务器将所述第一矩阵参数、所述第二矩阵参数和所述偏置参数与所述第一矩阵参数的平均梯度、所述第二矩阵参数的平均梯度和所述偏置参数的平均梯度对应求和,确定出所述第一矩阵参数的阶跃向量、所述第二矩阵参数的阶跃向量和所述偏置参数的阶跃向量,更新所述全局模型参数矩阵。
第三方面,本发明实施例提供一种运用于终端或区块链节点的联邦学习的模型训练的装置,包括:
获取模块,用于获取中央服务器第k次迭代后发送的第一矩阵参数和第二矩阵参数;所述第一矩阵参数和所述第二矩阵参数是所述中央服务器对全局模型参数矩阵进行分解得到的;所述k为自然数;
处理模块,用于使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度;
将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器和/或将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器,以使所述中央服务器更新所述全局模型参数矩阵。
可选的,所述处理模块具体用于:
使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述目标函数中所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度;
根据所述第一矩阵参数、所述第二矩阵参数和所述偏置参数、所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,确定出第一矩阵更新参数、第二矩阵更新参数和偏置更新参数;
将所述第一矩阵更新参数、所述第二矩阵更新参数和所述偏置更新参数进行反向传播计算,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度。
可选的,所述处理模块具体用于:
根据第一矩阵参数和所述第二矩阵参数创建转换函数,使用所述训练样本进行前向传播训练,确定出所述转换函数在前向传播中所有全连接层的损失函数;所述损失函数包括所述偏置参数;
确定出所述转换函数在前向传播中所有全连接层的损失函数的最小值,将所述最小值的损失函数作为目标函数;
根据逐元素乘法计算所述目标函数,确定出所述目标函数的误差函数;
根据随机梯度下降法计算所述误差函数,确定出所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度。
可选的,所述处理模块具体用于:
根据下述公式(1)确定出第一矩阵更新参数;根据下述公式(2)确定出第二矩阵更新参数;根据下述公式(3)确定出其他更新参数;
Figure BDA0002536538820000071
Figure BDA0002536538820000072
Figure BDA0002536538820000073
其中,
Figure BDA0002536538820000074
为第k+1次迭代的第l层的第一矩阵更新参数,
Figure BDA0002536538820000075
为第k次迭代的第l层的第一矩阵参数,αk+1为第k+1次迭代的学习率,
Figure BDA0002536538820000076
为所述第一矩阵参数的梯度,
Figure BDA0002536538820000077
为第k+1次迭代的第l层的第二矩阵更新参数,
Figure BDA0002536538820000078
为第k次迭代的第l层的第二矩阵参数,
Figure BDA0002536538820000079
为所述第二矩阵参数的梯度,
Figure BDA00025365388200000710
为第k+1次迭代的第l层的偏置更新参数,
Figure BDA00025365388200000711
为第k次迭代的第l层的偏置参数,
Figure BDA00025365388200000712
为所述偏置参数的更新梯度,其中l为正整数。
可选的,终端的数量为多个;
当所述终端的数量为偶数时,随机一半数量的终端的处理模块将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;另一半数量的终端的处理模块将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;当所述终端的数量为奇数时,每个终端的处理模块将所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;以使所述中央服务器根据多个所述终端发送的所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度更新所述全局模型参数矩阵。
第四方面,本发明实施例提供一种运用于终端或区块链节点的联邦学习的模型训练的装置,包括:
获取单元,用于获取全局模型参数矩阵;
处理单元,用于将所述全局模型参数矩阵分解为第一矩阵参数和第二矩阵参数;
将所述第一矩阵参数和所述第二矩阵参数发送至多个终端;以使所述多个终端对所述第一矩阵参数和所述第二矩阵参数进行训练。
可选的,所述处理单元还用于:
将所述第一矩阵参数和所述第二矩阵参数发送至多个终端之后,控制获取单元获取所述多个终端发送的第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度;
根据多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度,更新所述全局模型参数矩阵。
可选的,所述处理单元具体用于:
所述中央服务器将多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度进行联合平均计算,得到所述第一矩阵参数的平均梯度、所述第二矩阵参数的平均梯度和所述偏置参数的平均梯度;
所述中央服务器将当前全局模型中的第一矩阵参数、第二矩阵参数和偏置参数与所述第一矩阵的平均梯度、所述第二矩阵的平均梯度和所述偏置参数的平均梯度对应求和,确定出所述第一矩阵的阶跃向量、所述第二矩阵的阶跃向量和所述偏置参数的阶跃向量,更新所述全局模型参数矩阵。
第五方面,本发明实施例还提供一种计算设备,包括:
存储器,用于存储程序指令;
处理器,用于调用所述存储器中存储的程序指令,按照获得的程序执行上述运用于终端或区块链节点的联邦学习的模型训练的方法。
第六方面,本发明实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行上述运用于终端或区块链节点的联邦学习的模型训练的方法。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简要介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的一种联邦学习示意图;
图2为本发明实施例提供的一种系统架构示意图;
图3为本发明实施例提供的一种运用于终端或区块链节点的联邦学习的模型训练的方法的流程示意图;
图4为本发明实施例提供的一种运用于终端或区块链节点的联邦学习的模型训练的方法的流程示意图;
图5为本发明实施例提供的一种运用于终端或区块链节点的联邦学习的模型训练的结构示意图;
图6为本发明实施例提供的一种运用于终端或区块链节点的联邦学习的模型训练的结构示意图。
具体实施方式
为了使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明作进一步地详细描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
图2示例性的示出了本发明实施例所适用的一种系统架构,该系统架构包括中央服务器100和多个终端200。
其中,中央服务器100用于与多个终端200进行连接,并向多个终端200发送分解的第k次迭代的第一矩阵参数和第二矩阵参数。
终端200用于获取中央服务器100发送的第一矩阵参数和第二矩阵参数,并根据第一矩阵参数和第二矩阵参数使用训练样本进行训练,得到第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,发送至中央服务器100,以使中央服务器100更新全局模型参数矩阵,完成一次训练迭代。
本发明实施例中的终端可以为终端(手机)、个人台式电脑、笔记本电脑平板电脑、可穿戴设备、智能手表、和自动交通工具和区块链节点等设备。
需要说明的是,上述图2所示的结构仅是一种示例,本发明实施例对此不做限定。
基于上述描述,图3示例性的示出了本发明实施例提供的一种运用于终端或区块链节点的联邦学习的模型训练的方法的流程,该流程可由运用于终端或区块链节点的联邦学习的模型训练的装置执行。
如图3所示,该流程具体包括:
步骤301,终端获取中央服务器第k次迭代的第一矩阵参数和第二矩阵参数;所述第一矩阵参数和所述第二矩阵参数是所述中央服务器对全局模型参数矩阵进行分解得到的;所述k为自然数。
本发明实施例,在联邦学习中,每一次迭代,终端都需要获取中央服务器的全局模型,作为本地训练的模型,模型中包括第k次迭代的第一矩阵参数和第二矩阵参数,作为终端的模型,然后对训练样本进行训练,其中训练样本是终端的本地的数据。
需要说明的是,第一矩阵参数和第二矩阵参数是中央服务器对全局模型中的全局模型参数矩阵进行奇异值分解后再进行合并得到的,具体的,中央服务器对全局模型中的全局模型参数矩阵进行奇异值分解,得到矩阵一、矩阵二和矩阵三,然后将矩阵三开方后分别与矩阵一和矩阵二相乘,得到第一矩阵参数和第二矩阵参数。其中,矩阵三为对角矩阵,由全局模型参数矩阵的秩组成,矩阵的秩的值是可以依据经验设置的值,一般取值在10-20之间,若矩阵的秩的值越小,则分解后的矩阵越小,通过分解后的矩阵进行模型训练的结果误差相对越大,反之,若矩阵的秩的值越大,则分解后的矩阵越大,越接近于起始矩阵,通过分解后的矩阵进行模型训练的结果误差相对越小。
步骤302,所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度。
本发明实施例,终端使用训练样本对第一矩阵参数和第二矩阵参数进行训练,训练出第一矩阵参数的梯度、第二矩阵参数的梯度和偏置参数的梯度,再根据第一矩阵参数的梯度、第二矩阵参数的梯度和偏置参数的梯度得到第一矩阵更新参数、第二矩阵更新参数和偏置更新参数,然后根据第一矩阵更新参数、第二矩阵更新参数和偏置更新参数得到第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度。
进一步地,终端使用训练样本对第一矩阵参数和第二矩阵参数进行训练,得到目标函数中第一矩阵参数的梯度、第二矩阵参数的梯度和偏置参数的梯度;
终端根据第一矩阵参数、第二矩阵参数和偏置参数、第一矩阵参数的梯度、第二矩阵参数的梯度和偏置参数的梯度,确定出第一矩阵更新参数、第二矩阵更新参数和偏置更新参数;
终端将第一矩阵更新参数、第二矩阵更新参数和偏置更新参数进行反向传播计算,确定出第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度。
具体的,终端根据第一矩阵参数和第二矩阵参数创建转换函数,使用训练样本进行前向传播训练,确定出转换函数在前向传播中所有全连接层的损失函数;所述损失函数包括所述偏置参数;确定出转换函数在前向传播中所有全连接层的损失函数的最小值,将最小值的损失函数作为目标函数;根据逐元素乘法计算所述目标函数,确定出所述目标函数的误差函数;根据随机梯度下降法计算误差函数,确定出第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度。
终端使用训练样本对第一矩阵参数和第二矩阵参数进行前向传播训练,得到前向传播中第一层至最后一层全连接层的所有转换函数,然后根据训练样本继续训练,得到第一层至最后一层全连接层的所有损失函数,然后根据损失函数的值,将值最小值对应的损失函数确定为目标函数,其中,目标函数包括第一矩阵参数、第二矩阵参数和偏置参数,然后再根据元素乘法对目标函数进行计算,得到目标函数对应的误差函数,误差函数包括损失函数最小值对应的最后一层全连接层的目标函数的误差函数,和损失函数最小值对应的非最后一层全连接层的目标函数的误差函数,最后再使用梯度下降法中的随机梯度下降法得到目标函数的第一矩阵参数的梯度、第二矩阵参数的梯度和前置参数的梯度,其中,使用随机梯度下降方法的有益在于提升计算速度,也可以使用批量梯度的方法进行计算,在这里不做限制。
然后将得到的第一矩阵参数的梯度、第二矩阵参数的梯度和前置参数的梯度与学习率做乘积,在与当前的第一矩阵参数、第二矩阵参数和偏置参数分别求和,将结果确定为第一矩阵更新参数、第二矩阵更新参数和偏置更新参数,具体的,根据下述公式(1)确定出第一矩阵更新参数;根据下述公式(2)确定出第二矩阵更新参数;根据下述公式(3)确定出其他更新参数;
Figure BDA0002536538820000131
Figure BDA0002536538820000132
Figure BDA0002536538820000133
其中,
Figure BDA0002536538820000134
为第k+1次迭代的第l层的第一矩阵更新参数,
Figure BDA0002536538820000135
为第k次迭代的第l层的第一矩阵参数,αk为第k次迭代的学习率,
Figure BDA0002536538820000136
为第一矩阵参数的梯度,
Figure BDA0002536538820000137
为第k+1次迭代的第l层的第二矩阵更新参数,
Figure BDA0002536538820000138
为第k次迭代的第l层的第二矩阵参数,
Figure BDA0002536538820000139
为第二矩阵参数的梯度,
Figure BDA00025365388200001310
为第k+1次迭代的第l层的偏置更新参数,
Figure BDA00025365388200001311
为第k次迭代的第l层的偏置参数,
Figure BDA00025365388200001312
为偏置参数的更新梯度,其中l为正整数。
然后将第一矩阵更新参数、第二矩阵更新参数和偏置更新参数通过终端当前的模型进行反向传播训练,得到第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度。
下面将在具体实例中描述上述技术方案。
实例1
终端获取中央服务器第k次迭代的第一矩阵参数和第二矩阵参数,Al和Bl,其中l指的是神经网络中第l层全连接层,l为正整数。
终端使用训练样本对第一矩阵参数和第二矩阵参数进行训练,则第l层全连接层上的函数为下述公式(4)。
Figure BDA00025365388200001313
其中,al为函数输出,bl为第l层全连接层中的偏置参数。
当终端获取的模型中共有m层连接层时,则am为下述公式(5),m为不小于l的正整数。
Figure BDA0002536538820000141
其中,x为训练样本中数据,用于作为函数的输入变量。
根据公式(4)和(5),确定出所有全连接层的损失函数,得到第l层全连接层中的目标函数,目标函数为公式(6)。由公式(6)可以看出,第l层的全连接层的损失函数的值最小。
Figure BDA0002536538820000142
根据逐元素乘法计算目标函数公式(6),得到目标函数公式(6)的误差函数,误差函数为公式(7)
Figure BDA0002536538820000143
其中,⊙指的逐元素乘法,
Figure BDA0002536538820000144
为al的梯度。
再通过随机梯度下降法对公式(7)进行计算,确定出第一矩阵参数的更新梯度为:
Figure BDA0002536538820000145
第二矩阵参数的更新梯度为:
Figure BDA0002536538820000146
偏置参数的更新梯度为:
Figure BDA0002536538820000147
然后根据公式(1)得到第一矩阵更新参数
Figure BDA0002536538820000148
根据公式(2)得到第二矩阵更新参数
Figure BDA0002536538820000149
根据公式(3)得到偏置更新参数
Figure BDA00025365388200001410
然后将模型进行反向传播,根据第一矩阵更新参数、第二矩阵更新参数和偏置更新参数确定出第一矩阵参数的更新梯度
Figure BDA00025365388200001411
Figure BDA00025365388200001412
第二矩阵参数的更新梯度
Figure BDA00025365388200001413
Figure BDA00025365388200001414
偏置参数的更新梯度
Figure BDA0002536538820000151
Figure BDA0002536538820000152
步骤303,所述终端将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器和/或将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器,以使所述中央服务器更新所述全局模型参数矩阵。
本发明实施例,根据向中央服务器发送第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度的终端的数量,确定出终端向中央服务器发送参数的类型。若终端为偶数时,可以选择一半的终端发送第一矩阵参数的更新梯度和偏置参数的更新梯度至中央服务器,另一半的终端发送第二矩阵参数的更新梯度和偏置参数的更新梯度至中央服务器,否则终端将第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数全部发送至中央服务器。
示例性的,终端的数量为多个,当终端的数量为偶数时,随机一半数量的终端将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;另一半数量的终端将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;当所述终端的数量为奇数时,每个终端将所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;以使所述中央服务器根据多个所述终端发送的所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度更新所述全局模型参数矩阵。
本发明实施例中,中央服务器将全局模型参数矩阵发送至多个终端,以使终端根据参数矩阵进行模型训练,在参与的终端的数量为偶数时(如10个终端),随机一半的终端只发送第一矩阵参数的更新梯度和偏置参数的更新梯度发送至中央服务器(如随机选择10个终端中的5个终端的第一矩阵参数的更新梯度和偏置参数的更新梯度发送至中央服务器),剩下的另一半终端发送第二矩阵参数的更新梯度和偏置参数的更新梯度至中央服务器,可以使中央服务器在不影响计算结果及模型训练准确度的情况下,减少中央服务器的计算数据量,提升中央服务器的计算速度,减少迭代运行时间。
本发明实施例中,通过分解后的第一矩阵参数和第二矩阵参数模型训练,以使终端的模型减少参数,降低终端的模型训练时的内存消耗,还可以通过在终端的数量为偶数时,选择性的不发送第一矩阵参数的更新梯度或第二矩阵参数的更新梯度,以减少模型训练时的数据量,节省计算时间。
图4为本发明实施例提供的一种运用于移动终端或区块链节点的联邦学习的模型训练的方法的流程,如图4所示,具体流程包括:
步骤401,中央服务器获取全局模型参数矩阵。
本发明实施例,中央服务器通过全局模型获取到全局模型参数矩阵。
步骤402,所述中央服务器将所述全局模型参数矩阵分解为第一矩阵参数和第二矩阵参数。
本发明实施例,中央服务器对全局模型中的全局模型参数矩阵进行奇异值分解,得到分解结果,再将分解结果进行合并,得到第一矩阵参数和第二矩阵参数,例如,全局模型中的全局模型参数矩阵为Wl,通过奇异值分解后,Wl=USVT,其中,
Figure BDA0002536538820000163
S∈Rr×r,S为对角矩阵,r为可以选择的奇异值分解中矩阵的秩,则确定第一矩阵参数
Figure BDA0002536538820000161
第二矩阵参数
Figure BDA0002536538820000162
如,Wl是大小为1000×1000的矩阵,指定r=10,则Al是大小为1000×10的矩阵,Bl是大小为10×1000的矩阵,以此分解全局模型中的全局模型参数矩阵。
步骤403,所述中央服务器将所述第一矩阵参数和所述第二矩阵参数发送至多个终端;以使所述多个终端对所述第一矩阵参数和所述第二矩阵参数进行训练。
本发明实施例,中央服务器将分解得到的第一矩阵参数和第二矩阵参数发送至多个终端,以使多个终端减少终端的模型的参数,并对第一矩阵参数和第二矩阵参数进行训练。
中央服务器将第一矩阵参数和第二矩阵参数发送至多个终端之后,获取多个终端发送的第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,中央服务器再根据多个第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度,更新所述全局模型参数矩阵。
中央服务器将第一矩阵参数和第二矩阵参数发送至多个终端,以使终端根据对第一矩阵参数和第二矩阵参数进行训练得到第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,然后根据终端的数量获取第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,更新全局模型参数矩阵。
进一步地,中央服务器将多个所述第一矩阵参数的更新梯度、多个第二矩阵参数的更新梯度和多个偏置参数的更新梯度进行联合平均计算,得到第一矩阵参数的平均梯度、第二矩阵参数的平均梯度和偏置参数的平均梯度;
中央服务器将第一矩阵参数、第二矩阵参数和偏置参数与第一矩阵参数的平均梯度、第二矩阵参数的平均梯度和所述偏置参数的平均梯度对应求和,确定出第一矩阵参数的阶跃向量、第二矩阵参数的阶跃向量和偏置参数的阶跃向量,更新全局模型参数矩阵。
具体的,将多个第一矩阵参数的更新梯度相加,再根据发送第一矩阵参数的更新梯度的终端的数量,确定出第一矩阵参数的平均梯度,同理确定出第二矩阵参数的平均梯度和偏置参数的平均梯度,然后将第一矩阵参数的平均梯度与第一矩阵参数相加,得到第一矩阵参数的阶跃向量,同理确定出第二矩阵参数的阶跃向量和偏置参数的阶跃向量,根据第一矩阵参数的阶跃向量、第二矩阵参数的阶跃向量和偏置参数的阶跃向量更新全局模型参数矩阵。
结合图3的实例1,下面在具体实例中更新全局模型参数矩阵。
实例2
中央服务器获取到全局模型参数矩阵W为下述矩阵。
Figure BDA0002536538820000181
将W进行分解,得到第一参数矩阵A为下述矩阵:
Figure BDA0002536538820000182
得到第二参数矩阵BT下述矩阵:
Figure BDA0002536538820000183
将第一矩阵参数A和第二矩阵参数BT发送至6台终端,并得到终端发送的第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度,分别为
Figure BDA0002536538820000184
其中,m∈{1,2,3,4,5,6},m表示终端的编号,编号为1、2、和3的终端发送的是第一矩阵参数的更新梯度和偏置参数的更新梯度,编号为4、5、和6的终端发送的是第二矩阵参数的更新梯度和偏置参数的更新梯度,中央服务器将获取到的参数进行联合平均,得到第一矩阵的平均梯度、第二矩阵的平均梯度和所述偏置参数的平均梯度分别为:
Figure BDA0002536538820000185
Figure BDA0002536538820000186
再将第一矩阵参数、第二矩阵参数和偏置参数与第一矩阵参数的平均梯度、第二矩阵参数的平均梯度和偏置参数的平均梯度对应求和,确定出第一矩阵参数的阶跃向量、第二矩阵参数的阶跃向量和偏置参数的阶跃向量分别为
Figure BDA0002536538820000187
其中,k表示第k次迭代,然后根据第一矩阵参数的阶跃向量、第二矩阵参数的阶跃向量和偏置参数的阶跃向量更新全局模型参数矩阵,用于下一次迭代。
本发明实施例中,通过中央服务器分解全局模型参数矩阵为第一矩阵参数和第二矩阵参数,以使终端的模型减少参数,降低终端的模型训练时的内存消耗,且不影响终端的模型训练的效率。
基于相同的技术构思,图5示例性的示出了本发明实施例提供的一种移动运用于终端或区块链节点的联邦学习的模型训练的装置的结构,该装置可以执行图3中运用于移动终端或区块链节点的联邦学习的模型训练的方法的流程。
如图5所示,该装置具体包括:
获取模块501,用于获取中央服务器第k次迭代后发送的第一矩阵参数和第二矩阵参数;所述第一矩阵参数和所述第二矩阵参数是所述中央服务器对全局模型参数矩阵进行分解得到的;所述k为自然数;
处理模块502,用于使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度;
将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器和/或将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器,以使所述中央服务器更新所述全局模型参数矩阵。
可选的,所述处理模块502具体用于:
使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度;
根据所述第一矩阵参数、所述第二矩阵参数和所述偏置参数、所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,确定出第一矩阵更新参数、第二矩阵更新参数和偏置更新参数;
将所述第一矩阵更新参数、所述第二矩阵更新参数和所述偏置更新参数进行反向传播计算,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度。
可选的,所述处理模块502具体用于:
根据第一矩阵参数和所述第二矩阵参数创建转换函数,使用所述训练样本进行前向传播训练,确定出所述转换函数在前向传播中所有全连接层的损失函数;所述损失函数包括所述偏置参数;
确定出所述转换函数在前向传播中所有全连接层的损失函数的最小值,将所述最小值的损失函数作为目标函数;
根据逐元素乘法计算所述目标函数,确定出所述目标函数的误差函数;
根据随机梯度下降法计算所述误差函数,确定出所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度。
可选的,所述处理模块502具体用于:
根据下述公式(1)确定出第一矩阵更新参数;根据下述公式(2)确定出第二矩阵更新参数;根据下述公式(3)确定出其他更新参数;
Figure BDA0002536538820000201
Figure BDA0002536538820000202
Figure BDA0002536538820000203
其中,
Figure BDA0002536538820000204
为第k+1次迭代的第l层的第一矩阵更新参数,
Figure BDA00025365388200002012
为第k次迭代的第l层的第一矩阵参数,αk+1为第k+1次迭代的学习率,
Figure BDA0002536538820000205
为所述第一矩阵参数的梯度,
Figure BDA0002536538820000206
为第k+1次迭代的第l层的第二矩阵更新参数,
Figure BDA0002536538820000207
为第k次迭代的第l层的第二矩阵参数,
Figure BDA0002536538820000208
为所述第二矩阵参数的梯度,
Figure BDA0002536538820000209
为第k+1次迭代的第l层的偏置更新参数,
Figure BDA00025365388200002010
为第k次迭代的第l层的偏置参数,
Figure BDA00025365388200002011
为所述偏置参数的更新梯度,其中l为正整数。
可选的,终端的数量为多个;
当所述终端的数量为偶数时,随机一半数量的终端的处理模块502将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;另一半数量的终端的处理模块将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;当所述终端的数量为奇数时,每个终端的处理模块502将所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;以使所述中央服务器根据多个所述终端发送的所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度更新所述全局模型参数矩阵。
图6示例性的示出了本发明实施例提供的一种运用于移动终端或区块链节点的联邦学习的模型训练的装置的结构,该装置可以执行图4中运用于移动终端或区块链节点的联邦学习的模型训练的方法的流程。
如图6所示,该装置具体包括:
获取单元601,用于获取全局模型参数矩阵;
处理单元602,用于将所述全局模型参数矩阵分解为第一矩阵参数和第二矩阵参数;
将所述第一矩阵参数和所述第二矩阵参数发送至多个终端;以使所述多个终端对所述第一矩阵参数和所述第二矩阵参数进行训练。
可选的,所述处理单元602还用于:
将所述第一矩阵参数和所述第二矩阵参数发送至多个终端之后,控制获取单元获取所述多个终端发送的第一矩阵参数的更新梯度、第二矩阵参数的更新梯度和偏置参数的更新梯度;
根据多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度,更新所述全局模型参数矩阵。
可选的,所述处理单元602具体用于:
所述中央服务器将多个所述第一矩阵参数的更新梯度、多个所述第二矩阵参数的更新梯度和多个所述偏置参数的更新梯度进行联合平均计算,得到所述第一矩阵参数的平均梯度、所述第二矩阵参数的平均梯度和所述偏置参数的平均梯度;
所述中央服务器将所述第一矩阵参数、所述第二矩阵参数和所述偏置参数与所述第一矩阵参数的平均梯度、所述第二矩阵参数的平均梯度和所述偏置参数的平均梯度对应求和,确定出所述第一矩阵参数的阶跃向量、所述第二矩阵参数的阶跃向量和所述偏置参数的阶跃向量,更新所述全局模型参数矩阵。
基于相同的技术构思,本发明实施例还提供一种计算设备,包括:
存储器,用于存储程序指令;
处理器,用于调用存储器中存储的程序指令,按照获得的程序执行上述运用于移动终端或区块链节点的联邦学习的模型训练的方法。
基于相同的技术构思,本发明实施例还提供一种计算机可读存储介质,计算机可读存储介质存储有计算机可执行指令,计算机可执行指令用于使计算机执行上述运用于移动终端或区块链节点的联邦学习的模型训练的方法。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
显然,本领域的技术人员可以对本申请进行各种改动和变型而不脱离本申请的精神和范围。这样,倘若本申请的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包含这些改动和变型在内。

Claims (10)

1.一种基于联邦学习的模型训练的方法,其特征在于,包括:
终端获取中央服务器第k次迭代的第一矩阵参数和第二矩阵参数;所述第一矩阵参数和所述第二矩阵参数是所述中央服务器对全局模型参数矩阵进行分解得到的;所述k为自然数;
所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度;
所述终端将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器和/或将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器,以使所述中央服务器更新所述全局模型参数矩阵。
2.如权利要求1所述的方法,其特征在于,所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度,包括:
所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度;
所述终端根据所述第一矩阵参数、所述第二矩阵参数和所述偏置参数、所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,确定出第一矩阵更新参数、第二矩阵更新参数和偏置更新参数;
所述终端将所述第一矩阵更新参数、所述第二矩阵更新参数和所述偏置更新参数进行反向传播计算,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度。
3.如权利要求2所述的方法,其特征在于,所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,包括:
根据第一矩阵参数和所述第二矩阵参数创建转换函数,使用所述训练样本进行前向传播训练,确定出所述转换函数在前向传播中所有全连接层的损失函数;所述损失函数包括所述偏置参数;
确定出所述转换函数在前向传播中所有全连接层的损失函数的最小值,将所述最小值的损失函数作为目标函数;
根据逐元素乘法计算所述目标函数,确定出所述目标函数的误差函数;
根据随机梯度下降法计算所述误差函数,确定出所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度。
4.如权利要求2所述的方法,其特征在于,根据下述公式(1)确定出第一矩阵更新参数;根据下述公式(2)确定出第二矩阵更新参数;根据下述公式(3)确定出其他更新参数;
Figure FDA0002536538810000021
Figure FDA0002536538810000022
Figure FDA0002536538810000023
其中,
Figure FDA0002536538810000024
为第k+1次迭代的第l层的第一矩阵更新参数,
Figure FDA0002536538810000025
为第k次迭代的第l层的第一矩阵参数,αk+1为第k+1次迭代的学习率,
Figure FDA0002536538810000026
为所述第一矩阵参数的梯度,
Figure FDA0002536538810000027
为第k+1次迭代的第l层的第二矩阵更新参数,
Figure FDA0002536538810000028
为第k次迭代的第l层的第二矩阵参数,
Figure FDA0002536538810000029
为所述第二矩阵参数的梯度,
Figure FDA00025365388100000210
为第k+1次迭代的第l层的偏置更新参数,
Figure FDA00025365388100000211
为第k次迭代的第l层的偏置参数,
Figure FDA00025365388100000212
为所述偏置参数的更新梯度,其中l为正整数。
5.如权利要求1至4任一项所述的方法,其特征在于,所述终端的数量为多个;
所述方法还包括:
当所述终端的数量为偶数时,随机一半数量的终端将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;另一半数量的终端将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;当所述终端的数量为奇数时,每个终端将所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器;以使所述中央服务器根据多个所述终端发送的所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度更新所述全局模型参数矩阵。
6.一种基于联邦学习的模型训练的装置,其特征在于,包括
获取模块,用于获取中央服务器第k次迭代后发送的第一矩阵参数和第二矩阵参数;所述第一矩阵参数和所述第二矩阵参数是所述中央服务器对全局模型参数矩阵进行分解得到的;所述k为自然数;
处理模块,用于使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和偏置参数的更新梯度;
将所述第一矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器和/或将所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度发送至所述中央服务器,以使所述中央服务器更新所述全局模型参数矩阵。
7.如权利要求6所述的装置,其特征在于,所述处理模块具体用于:
所述终端使用训练样本对所述第一矩阵参数和所述第二矩阵参数进行训练,得到所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度;
所述终端根据所述第一矩阵参数、所述第二矩阵参数和所述偏置参数、所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度,确定出第一矩阵更新参数、第二矩阵更新参数和偏置更新参数;
所述终端将所述第一矩阵更新参数、所述第二矩阵更新参数和所述偏置更新参数进行反向传播计算,确定出所述第一矩阵参数的更新梯度、所述第二矩阵参数的更新梯度和所述偏置参数的更新梯度。
8.如权利要求7所述的装置,其特征在于,所述处理模块具体用于:
根据第一矩阵参数和所述第二矩阵参数创建转换函数,使用所述训练样本进行前向传播训练,确定出所述转换函数在前向传播中所有全连接层的损失函数;所述损失函数包括所述偏置参数;
确定出所述转换函数在前向传播中所有全连接层的损失函数的最小值,将所述最小值的损失函数作为目标函数;
根据逐元素乘法计算所述目标函数,确定出所述目标函数的误差函数;
根据随机梯度下降法计算所述误差函数,确定出所述第一矩阵参数的梯度、所述第二矩阵参数的梯度和所述偏置参数的梯度。
9.一种计算设备,其特征在于,包括:
存储器,用于存储程序指令;
处理器,用于调用所述存储器中存储的程序指令,按照获得的程序执行权利要求1至5任一项所述的方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行权利要求1至5任一项所述的方法。
CN202010534434.5A 2020-06-12 2020-06-12 一种基于联邦学习的模型训练的方法及装置 Pending CN111695696A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010534434.5A CN111695696A (zh) 2020-06-12 2020-06-12 一种基于联邦学习的模型训练的方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010534434.5A CN111695696A (zh) 2020-06-12 2020-06-12 一种基于联邦学习的模型训练的方法及装置

Publications (1)

Publication Number Publication Date
CN111695696A true CN111695696A (zh) 2020-09-22

Family

ID=72480757

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010534434.5A Pending CN111695696A (zh) 2020-06-12 2020-06-12 一种基于联邦学习的模型训练的方法及装置

Country Status (1)

Country Link
CN (1) CN111695696A (zh)

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111865815A (zh) * 2020-09-24 2020-10-30 中国人民解放军国防科技大学 一种基于联邦学习的流量分类方法及系统
CN112164224A (zh) * 2020-09-29 2021-01-01 杭州锘崴信息科技有限公司 信息安全的交通信息处理系统、方法、设备及存储介质
CN112288100A (zh) * 2020-12-29 2021-01-29 支付宝(杭州)信息技术有限公司 一种基于联邦学习进行模型参数更新的方法、系统及装置
CN112418440A (zh) * 2020-11-27 2021-02-26 网络通信与安全紫金山实验室 一种边-端协同梯度压缩聚合方法以及装置
CN113094761A (zh) * 2021-04-25 2021-07-09 中山大学 一种联邦学习数据防篡改监测方法及相关装置
CN113553377A (zh) * 2021-07-21 2021-10-26 湖南天河国云科技有限公司 基于区块链和联邦学习的数据共享方法及装置
CN114297722A (zh) * 2022-03-09 2022-04-08 广东工业大学 一种基于区块链的隐私保护异步联邦共享方法及系统
CN117351299A (zh) * 2023-09-13 2024-01-05 北京百度网讯科技有限公司 图像生成及模型训练方法、装置、设备和存储介质

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111865815A (zh) * 2020-09-24 2020-10-30 中国人民解放军国防科技大学 一种基于联邦学习的流量分类方法及系统
CN111865815B (zh) * 2020-09-24 2020-11-24 中国人民解放军国防科技大学 一种基于联邦学习的流量分类方法及系统
CN112164224A (zh) * 2020-09-29 2021-01-01 杭州锘崴信息科技有限公司 信息安全的交通信息处理系统、方法、设备及存储介质
CN112418440A (zh) * 2020-11-27 2021-02-26 网络通信与安全紫金山实验室 一种边-端协同梯度压缩聚合方法以及装置
CN112418440B (zh) * 2020-11-27 2024-02-13 网络通信与安全紫金山实验室 一种边-端协同梯度压缩聚合方法以及装置
CN112288100A (zh) * 2020-12-29 2021-01-29 支付宝(杭州)信息技术有限公司 一种基于联邦学习进行模型参数更新的方法、系统及装置
CN113094761A (zh) * 2021-04-25 2021-07-09 中山大学 一种联邦学习数据防篡改监测方法及相关装置
CN113553377A (zh) * 2021-07-21 2021-10-26 湖南天河国云科技有限公司 基于区块链和联邦学习的数据共享方法及装置
CN113553377B (zh) * 2021-07-21 2022-06-21 湖南天河国云科技有限公司 基于区块链和联邦学习的数据共享方法及装置
CN114297722A (zh) * 2022-03-09 2022-04-08 广东工业大学 一种基于区块链的隐私保护异步联邦共享方法及系统
CN117351299A (zh) * 2023-09-13 2024-01-05 北京百度网讯科技有限公司 图像生成及模型训练方法、装置、设备和存储介质

Similar Documents

Publication Publication Date Title
CN111695696A (zh) 一种基于联邦学习的模型训练的方法及装置
CN112181666B (zh) 一种基于边缘智能的设备评估和联邦学习重要性聚合方法
CN111242282B (zh) 基于端边云协同的深度学习模型训练加速方法
CN113221183B (zh) 实现隐私保护的多方协同更新模型的方法、装置及系统
CN113033712B (zh) 一种基于联邦学习的多用户协同训练人流统计方法及系统
CN110689136B (zh) 一种深度学习模型获得方法、装置、设备及存储介质
US11651198B2 (en) Data processing method and apparatus for neural network
CN111723947A (zh) 一种联邦学习模型的训练方法及装置
US20220318412A1 (en) Privacy-aware pruning in machine learning
CN112948885B (zh) 实现隐私保护的多方协同更新模型的方法、装置及系统
CN113608881B (zh) 内存分配方法、装置、设备、可读存储介质及程序产品
CN113241064A (zh) 语音识别、模型训练方法、装置、电子设备和存储介质
CN111325340A (zh) 信息网络关系预测方法及系统
CN114580636A (zh) 基于三目标联合优化的神经网络轻量化部署方法
CN114595815A (zh) 一种面向传输友好的云-端协作训练神经网络模型方法
CN114970830A (zh) 一种面向数据并行分布式深度学习训练加速的灵活通信方法
CN109697511B (zh) 数据推理方法、装置及计算机设备
CN114444688A (zh) 神经网络的量化方法、装置、设备、存储介质及程序产品
CN109981361B (zh) 一种传播网络中感染源的确定方法及装置
CN117786416B (zh) 一种模型训练方法、装置、设备、存储介质及产品
CN117521737B (zh) 网络模型的转换方法、装置、终端及计算机可读存储介质
CN117494816B (zh) 基于计算单元部署的模型推理方法、装置、设备及介质
CN111330269B (zh) 应用难度调整和策略确定方法、装置、系统、设备及介质
CN115834247B (zh) 一种基于区块链的边缘计算信任评价方法
CN116306781A (zh) 一种基于神经网络模型的数据处理方法、装置和电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination