CN113011603A - 模型参数更新方法、装置、设备、存储介质及程序产品 - Google Patents
模型参数更新方法、装置、设备、存储介质及程序产品 Download PDFInfo
- Publication number
- CN113011603A CN113011603A CN202110287041.3A CN202110287041A CN113011603A CN 113011603 A CN113011603 A CN 113011603A CN 202110287041 A CN202110287041 A CN 202110287041A CN 113011603 A CN113011603 A CN 113011603A
- Authority
- CN
- China
- Prior art keywords
- model
- parameter
- local
- local iteration
- loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种模型参数更新方法、装置、设备、存储介质及程序产品,所述方法包括:计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;采用所述梯度值更新所述参数,以完成本轮本地迭代。本发明实现了在通过增加本地迭代次数减少通信成本的同时,还能够保证模型的预测准确率。
Description
技术领域
本发明涉及机器学习技术领域,尤其涉及一种模型参数更新方法、装置、设备、存储介质及程序产品。
背景技术
随着人工智能的发展,人们为解决数据孤岛的问题,提出了“联邦学习”的概念,使得联邦双方在不用给出己方数据的情况下,也可进行模型训练得到模型参数,并且可以避免数据隐私泄露的问题。纵向联邦学习,纵向联邦学习是在参与者的数据特征重叠较小,而用户重叠较多的情况下,取出参与者用户相同而用户数据特征不同的那部分用户及数据进行联合训练机器学习模型。
在纵向联邦学习过程中,拥有标签数据的参与方需要与其他参与方之间进行多次通信,以传输对方更新参数时所需的中间结果,例如模型输出或模型输出对应的梯度值。双方需要进行多轮联合参数更新,也即需要进行多次通信,因此通信成本较大。针对这一问题,目前提出了参与方利用其他参与方发送的一次中间结果在本地进行多轮本地迭代的方案,通过增加本地迭代次数来减少联合更新参数的次数,从而减少通信成本。
但是,该方案中当参与方本地迭代次数较多时,容易出现参数失真的问题,导致模型的性能无法保障,而当本地迭代次数较少时,又无法有效地减少通信成本。
发明内容
本发明的主要目的在于提供一种模型参数更新方法、装置、设备、存储介质及程序产品,旨在目前纵向联邦学习方案中通信成本和模型性能难以兼顾的问题。
为实现上述目的,本发明提供一种模型参数更新方法,所述方法应用于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述方法包括以下步骤:
计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
采用所述梯度值更新所述参数,以完成本轮本地迭代。
可选地,所述计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量的步骤包括:
将所述第一设备中第一模型的参数在本轮本地迭代中的参数向量与在预设历史轮次的本地迭代中的参数向量进行对应元素相减,得到差向量;
计算所述差向量中各元素的平方和,基于所述平方和得到所述近端优化损失。
可选地,当所述第一设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中模型的输出,
所述基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值的步骤包括:
将所述第一设备的训练数据输入所述第一设备中的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
根据所述模型输出和所述纵向联邦中间结果计算得到预测结果,并基于所述预测结果和所述训练数据对应的标签数据计算得到预测损失;
将所述预测损失和所述近端优化损失相加得到总损失,基于所述总损失计算得到所述参数对应的梯度值。
可选地,当所述第二设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中预测损失相对于所述第一设备在本轮联合参数更新时发送的第一模型的输出的梯度值,
所述基于所述近端优化损失、所述模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值的步骤包括:
将所述第一设备的训练数据输入所述第一设备的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
根据所述模型输出和所述纵向联邦中间结果计算得到所述预测损失相对于所述参数的第一子梯度值;
计算所述近端优化损失相对于所述参数的第二子梯度值,将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值。
可选地,所述将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值的步骤包括:
将所述第二子梯度值乘以预设调节系数后加上所述第一子梯度值得到所述参数对应的梯度值。
为实现上述目的,本发明提供一种用户风险预测方法,所述方法应用于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述方法包括以下步骤:
基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型,其中,所述近端优化损失表征本端待训练模型的参数在当次本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
采用所述本端风险预测模型预测得到待预测用户的风险值。
可选地,所述基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型的步骤包括:
接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果;
基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新;
检测更新参数后的本端待训练模型是否满足预设模型条件;
若满足,则将更新参数后的本端待训练模型作为所述本端风险预测模型;
若不满足,则返回执行所述接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果的步骤。
可选地,所述基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新的步骤包括:
计算近端优化损失,并基于所述近端优化损失、所述本端待训练模型在本轮本地迭代中的模型输出以及所述纵向联邦中间结果,计算得到所述参数对应的梯度值;
采用所述梯度值更新所述参数以完成本轮本地迭代;
检测本地迭代轮数是否达到预设轮数;
若达到,则执行所述检测更新参数后的本端待训练模型是否满足预设模型条件的步骤;
若未达到,则返回执行所述计算近端优化损失的步骤,并将所述本地迭代轮数自增1。
为实现上述目的,本发明提供一种模型参数更新装置,所述装置部署于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述装置包括:
第一计算模块,用于计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
第二计算模块,用于基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
更新模块,用于采用所述梯度值更新所述参数,以完成本轮本地迭代。
为实现上述目的,本发明提供一种用户风险预测装置,所述装置部署于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述装置包括:
联邦学习模块,用于基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型,其中,所述近端优化损失表征本端待训练模型的参数在当次本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
预测模块,用于采用所述本端风险预测模型预测得到待预测用户的风险值。
为实现上述目的,本发明还提供一种模型参数更新设备,所述模型参数更新设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的模型参数更新程序,所述模型参数更新程序被所述处理器执行时实现如上所述的模型参数更新方法的步骤。
为实现上述目的,本发明还提供一种用户风险预测设备,所述用户风险预测设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的用户风险预测程序,所述用户风险预测程序被所述处理器执行时实现如上所述的用户风险预测方法的步骤。
此外,为实现上述目的,本发明还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有模型参数更新程序,所述模型参数更新程序被处理器执行时实现如上所述的模型参数更新方法的步骤。
此外,为实现上述目的,本发明还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有用户风险预测程序,所述用户风险预测程序被处理器执行时实现如上所述的用户风险预测方法的步骤。
此外,为实现上述目的,本发明还提出一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上所述的模型参数更新方法的步骤。
此外,为实现上述目的,本发明还提出一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上所述的用户风险预测方法的步骤。
相比于现有方案,在本发明中,参与纵向联邦学习的第一设备在进行本地迭代时,增加计算能够表征第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量的近端优化损失,并基于近端优化损失、第一模型在本轮本地迭代中的模型输出以及从第二设备接收到的纵向联邦中间结果计算第一模型中参数对应的梯度值,根据梯度值来更新参数,也即增加近端优化损失来约束第一模型的参数在本地迭代中的变化量,从而避免本地迭代时参数值变化过大导致失真,实现了在通过增加本地迭代次数减少通信成本的同时,还能够保证模型的预测准确率。
附图说明
图1为本发明实施例方案涉及的硬件运行环境的结构示意图;
图2为本发明模型参数更新方法第一实施例的流程示意图;
图3为本发明实施例涉及的一种参与方进行联合参数更新的示意图;
图4为本发明实施例涉及的一种第一设备和第二设备进行纵向联邦学习的硬件架构图;
图5为本发明实施例涉及的一种第一设备与第二设备进行多轮联合参数更新的交互流程示意图;
图6为本发明模型参数更新装置较佳实施例的功能示意图模块图。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
如图1所示,图1是本发明实施例方案涉及的硬件运行环境的设备结构示意图。
需要说明的是,本发明实施例模型参数更新设备可以是智能手机、个人计算机和服务器等设备,在此不做具体限制,模型参数更新设备可以是参与纵向联邦学习的第一设备。
如图1所示,该模型参数更新设备可以包括:处理器1001,例如CPU,网络接口1004,用户接口1003,存储器1005,通信总线1002。其中,通信总线1002用于实现这些组件之间的连接通信。用户接口1003可以包括显示屏(Display)、输入单元比如键盘(Keyboard),可选用户接口1003还可以包括标准的有线接口、无线接口。网络接口1004可选的可以包括标准的有线接口、无线接口(如WI-FI接口)。存储器1005可以是高速RAM存储器,也可以是稳定的存储器(non-volatile memory),例如磁盘存储器。存储器1005可选的还可以是独立于前述处理器1001的存储装置。
本领域技术人员可以理解,图1中示出的设备结构并不构成对模型参数更新设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图1所示,作为一种计算机存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块以及模型参数更新程序。其中,操作系统是管理和控制设备硬件和软件资源的程序,支持模型参数更新程序以及其它软件或程序的运行。在图1所示的设备中,用户接口1003主要用于与客户端进行数据通信;网络接口1004主要用于与参与纵向联邦学习的第二设备建立通信连接;处理器1001可以用于调用存储器1005中存储的模型参数更新程序,并执行以下操作:
计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
采用所述梯度值更新所述参数,以完成本轮本地迭代。
进一步地,所述计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量的步骤包括:
将所述第一设备中第一模型的参数在本轮本地迭代中的参数向量与在预设历史轮次的本地迭代中的参数向量进行对应元素相减,得到差向量;
计算所述差向量中各元素的平方和,基于所述平方和得到所述近端优化损失。
进一步地,当所述第一设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中模型的输出,
所述基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值的步骤包括:
将所述第一设备的训练数据输入所述第一设备中的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
根据所述模型输出和所述纵向联邦中间结果计算得到预测结果,并基于所述预测结果和所述训练数据对应的标签数据计算得到预测损失;
将所述预测损失和所述近端优化损失相加得到总损失,基于所述总损失计算得到所述参数对应的梯度值。
进一步地,当所述第二设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中预测损失相对于所述第一设备在本轮联合参数更新时发送的第一模型的输出的梯度值,
所述基于所述近端优化损失、所述模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值的步骤包括:
将所述第一设备的训练数据输入所述第一设备的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
根据所述模型输出和所述纵向联邦中间结果计算得到所述预测损失相对于所述参数的第一子梯度值;
计算所述近端优化损失相对于所述参数的第二子梯度值,将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值。
进一步地,所述将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值的步骤包括:
将所述第二子梯度值乘以预设调节系数后加上所述第一子梯度值得到所述参数对应的梯度值。
本发明实施例还提出一种用户风险预测设备,所述用户风险预测设备是参与纵向联邦学习的第一设备,第一设备与参与纵向联邦学习的第二设备建立通信连接,所述用户风险预测设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的用户风险预测程序,所述用户风险预测程序被所述处理器执行时实现如下步骤:
基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型,其中,所述近端优化损失表征本端待训练模型的参数在当次本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
采用所述本端风险预测模型预测得到待预测用户的风险值。
进一步地,所述基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型的步骤包括:
接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果;
基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新;
检测更新参数后的本端待训练模型是否满足预设模型条件;
若满足,则将更新参数后的本端待训练模型作为所述本端风险预测模型;
若不满足,则返回执行所述接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果的步骤。
进一步地,所述基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新的步骤包括:
计算近端优化损失,并基于所述近端优化损失、所述本端待训练模型在本轮本地迭代中的模型输出以及所述纵向联邦中间结果,计算得到所述参数对应的梯度值;
采用所述梯度值更新所述参数以完成本轮本地迭代;
检测本地迭代轮数是否达到预设轮数;
若达到,则执行所述检测更新参数后的本端待训练模型是否满足预设模型条件的步骤;
若未达到,则返回执行所述计算近端优化损失的步骤,并将所述本地迭代轮数自增1。
基于上述的结构,提出模型参数更新方法的各实施例。
参照图2,图2为本发明模型参数更新方法第一实施例的流程示意图。需要说明的是,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。本发明模型参数更新方法应用于参与纵向联邦学习的第一设备,第一设备与参与纵向联邦学习的第二设备通信连接,第一设备和第二设备可以是智能手机、个人计算机和服务器等设备。在本实施例中,模型参数更新方法包括:
步骤S10,计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
在本实施例中,纵向联邦学习中的参与方分为两类,一类是拥有标签数据的数据应用参与方,一类是没有标签数据的数据提供参与方,一般情况下,数据应用参与方有一个,数据提供参与方有一个或多个。各个参与方分别部署有基于各自数据特征构建的数据集和机器学习模型,各个参与方的机器学习模型组合起来构成一个完整的模型,用于完成预测或分类等模型任务。其中,各个参与方的数据集的样本维度是对齐的,也即,各个数据集的样本ID是相同的,但是各个参与方的数据特征可各不相同。各个参与方预先可采用加密样本对齐的方式来构建样本维度对齐的数据集,在此不进行详细赘述。
参与方部署的机器学习模型可以是普通的机器学习模型,例如线性回归模型、神经网络模型等,也可以是自动机器学习中使用的模型,例如搜索网络。搜索网络是指用于进行模型参数更新(NAS)的网络;搜索网络中包括多个单元,每个单元对应一个网络层,其中部分单元之间设置有连接操作,以其中两个单元为例,这两个单元之前的连接操作可以是预先设置的N种连接操作,并定义了每种连接操作对应的权重,该权重即搜索网络的结构参数,单元内的网络层参数即搜索网络的模型参数;在模型训练过程中,需要进行模型参数更新以优化更新结构参数和模型参数,基于最终更新的结构参数即可确定最终的网络结构,即确定保留哪个或哪些连接操作。由于该网络的结构是经过网络搜索之后才确定的,各个参与方不需要像设计传统纵向联邦学习的模型一样去设置模型的网络结构,从而降低了设计模型的难度。
在本实施例中,第一设备可以是拥有标签数据的数据应用参与方,对应地第二设备是没有标签数据的数据提供参与方,此时第二设备可以有多个;或者,第一设备也可以是没有标签数据的数据提供参与方,对应地第二设备是拥有标签数据的数据应用参与方。未示区分,以下将第一设备中的模型称为第一模型,将第二设备中的模型称为第二模型。
各参与方中模型的参数预先进行初始化设置,各个参与方进行多轮联合参数更新,以不断更新各自模型中的参数,提高整个模型的性能,如预测准确率。当各个参与方的模型是普通的机器学习模型时,各轮联合参数更新过程中所更新的参数为模型参数,例如神经网络中的权重参数。当各个参与方的模型是搜索网络时,各轮联合参数更新时所更新的参数可以是结构参数和/或模型参数。具体地,在本实施例中并不限制对结构参数和模型参数的更新顺序。例如,可以在前几轮联合参数更新时对结构参数进行更新,后几轮联合参数更新时对模型参数进行更新。又如,可以每一轮联合参数更新时都是对结构参数和模型参数一起更新。
在一轮联合更新参数的过程中,各参与方先交互用于更新各自模型中参数的中间结果(以下也称为纵向联邦中间结果);各个参与方分别根据接收到的中间结果在本地进行多轮本地迭代,本地迭代后,再进行下一轮联合参数更新。也即,在一轮联合参数更新过程中,参与方只接收到其他参与方发送的一次中间结果,后续的多轮本地迭代都采用该中间结果来参与计算。其中,中间结果可以是梯度,也可以是模型的输出。具体地,当参与方是数据提供参与方时,发送给对方的中间结果可以是该参与方中模型的输出;当参与方是数据应用参与方时,发送给对方的中间结果可以是计算得到的数据提供方所发送模型输出对应的梯度。由于传递的是中间结果而不是数据集中的原始数据,使得各个参与方互相之间并没有泄露各自的数据隐私,保护了各个参与方的数据安全。如图3所示,为一实施例中参与方进行联合参数更新的示意图,其中,Party K是数据应用参与方,Party 1~Party K-1是数据提供参与方,NetK是数据应用参与方中部署的模型,Netj是数据提供参与方中部署的模型,Netc是数据应用参与方中部署的用于基于各方模型输出计算预测结果(Yout)的模型,Nj是模型的输出,是模型输出对应的梯度值。
第一设备在进行一轮本地迭代时,可计算近端优化损失。近端优化损失能够表征第一模型的参数在本轮本地迭代中的参数值与预设历史轮次的本地迭代中参数值之间的变化量。通过最小化近端优化损失,就可以约束本轮本地迭代中第一模型的参数值相比于前面历史参数值的变化幅度,也即,使得第一模型的参数值在本轮本地迭代时变化较小,从而避免在进行较多轮次的本地迭代后参数值失真。在本实施例中,对近端优化损失的计算方式并不做限制,对近端优化损失的最小化方法也不做限制。例如,在一实施方式中,可将近端优化损失作为损失函数,通过最小化损失函数的方法最小化近端优化损失,例如通过梯度下降算法计算近端优化损失相对于第一模型中参数的梯度值,根据梯度值优化参数,进而达到最小化近端优化损失的目的。在其他实施方式中,也可以采用其他的方法来最小化近端优化损失,例如,通过随机改变第一模型中参数的参数值,计算近端损失函数是否变小,随机试验获得最小化近端优化损失的参数值。
其中,预设历史轮次可以是在第一设备中预先设置的一个轮次,该轮次早于本轮本地迭代的轮次,如本轮本地迭代是第t轮,则预设历史轮次小于t。在一轮联合参数更新中,预设历史轮次可以是固定的,也即,该轮联合参数更新中的各轮本地迭代,都以同一历史轮次的本地迭代时的参数值为基础计算近端优化损失,例如预设历史轮次固定设置为1,以保证后面各轮本地迭代时参数值相比于第一轮本地迭代时的参数值变化量较小。或者,在一轮联合参数更新中,预设历史轮次也可以不是固定的,也即,针对该轮联合参数更新中的各轮本地迭代,可设置不同的预设历史轮次;当针对各轮本地迭代单独设置历史预设轮次时,若以该轮本地迭代的上一轮本地迭代为基础计算近端优化损失,则依据该近端优化损失计算出的参数的梯度值可能为0,也即近端优化损失没有起到约束作用,因此,在一优选实施方式中,对于全部轮次或部分轮次的本地迭代,该轮本地迭代的轮次减去其对应的预设历史轮次应当大于1。需要说明的是,在一些实施方式中,第一设备可以不用每一轮本地迭代都计算近端优化损失,例如,第一轮本地迭代时没有历史轮次的本地迭代,故无需计算近端优化损失。
进一步地,在一实施方式中,所述步骤S10包括:
步骤S101,将所述第一设备中第一模型的参数在本轮本地迭代中的参数向量与在预设历史轮次的本地迭代中的参数向量进行对应元素相减,得到差向量;
步骤S102,计算所述差向量中各元素的平方和,基于所述平方和得到所述近端优化损失。
第一模型的参数有多个,可采用向量的形式来表示。第一设备可将第一模型的参数在本轮本地迭代中的参数向量与在预设历史轮次的本地迭代中的参数向量进行对应元素相减,得到由各个差值构成的差向量,再计算差向量中各个元素的平方和。第一设备可将平方和直接作为近端优化损失,也可以计算平方和的平方根后作为近端优化损失。需要说明的是,在计算近端优化损失的过程中,第一设备是将本轮本地迭代中的参数向量的各个元素作为未知变量来参与计算。
在其他实施方式中,第一设备也可以采用其他能够计算向量之间变化量的计算方法来计算近端优化损失。
步骤S20,基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
第一设备将数据集中的训练数据输入到第一模型中,经过第一模型处理得到模型输出,根据近端优化损失、该模型输出和从第二设备接收到的纵向联邦中间结果,计算得到第一模型中各个参数对应的梯度值。其中,从第二设备接收到的纵向联邦中间结果即在本轮联合参数更新时第二设备发送的中间结果。具体地,当第一设备是拥有标签数据的数据应用参与方时,纵向联邦中间结果是第二设备发送的第二模型的输出,第一设备可根据纵向联邦中间结果和本轮本地迭代中的模型输出计算预测损失;在一实施方式中,第一设备可将近端优化损失和预测损失相加,得到一个总损失,再计算总损失相对于第一模型中参数的梯度值;在另一实施方式中,第一设备分别计算近端优化损失和预测损失相对于第一模型中参数的梯度值,将两个梯度值将加,得到最终的梯度值。当第一设备是没有标签数据的数据提供参与方,第二设备是拥有标签数据的数据应用参与方时,纵向联邦中间结果是第二设备计算的预测损失相对于第一模型的输出的梯度值;第一设备可以根据纵向联邦中间结果和本轮本地迭代中的模型输出计算出预测损失相对于第一模型中参数的梯度值,再计算近端优化损失相对于第一模型中参数的梯度值,将两个梯度值相加得到最终的梯度值。需要说明的是,在本发明各实施例中,根据损失计算梯度值的方法可参照现有的梯度计算方法,不作详细赘述。
步骤S30,采用所述梯度值更新所述参数,以完成本轮本地迭代。
第一设备在计算得到第一模型中各个参数的梯度值后,采用各个梯度值对各个参数进行更新。也即,每一个参数对应一个梯度值,第一设备采用该参数对应的梯度值来更新该参数。具体地,第一设备可将该参数在上一轮本地迭代更新后的参数值加上该参数对应的梯度值乘以学习率,得到该参数在本轮本地迭代更新后的参数值。对各个参数进行更新后,即完成了本轮本地迭代。通过在计算参数的梯度值时增加近端优化损失,再根据梯度值来更新参数,能够使得参数朝着最小化近端优化损失的方向变化,从而约束参数的变化量,避免参数值变化过大而失真。
进一步地,第一设备在完成本轮本地迭代后,若检测到达到了本轮联合参数更新的本地迭代轮数,则可以进行下一轮联合参数更新;若检测到还未达到本轮联合参数更新的本地迭代轮数,则可以进行下一轮本地迭代。在一实施方式中,可以设置一个联合更新参数的最大轮数,当达到该轮数时,第一设备停止对模型参数的更新。或,在另一实施方式中,第一设备可在一轮联合参数更新结束后,或者在一轮本地迭代结束后,检测预测损失是否收敛,若收敛,则停止对参数的更新。在停止对参数更新后,第一设备将当前参数值作为第一模型最终的参数值,确定第一模型的参数值后,即可采用第一模型完成预测任务。
如图4所示为一实施方式中参与纵向联邦学习的第一设备和第二设备的硬件架构图,第一设备和第二设备交互中间结果,基于对方发送的中间结果,各自在本地进行多轮本地迭代,在各轮本地迭代时,增加计算近端优化损失来约束参数的变化量,避免参数值变化过大而失真。
相比于现有方案,在本实施例中,参与纵向联邦学习的第一设备在进行本地迭代时,增加计算能够表征第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量的近端优化损失,并基于近端优化损失、第一模型在本轮本地迭代中的模型输出以及从第二设备接收到的纵向联邦中间结果计算第一模型中参数对应的梯度值,根据梯度值来更新参数,也即增加近端优化损失来约束第一模型的参数在本地迭代中的变化量,从而避免本地迭代时参数值变化过大导致失真,实现了在通过增加本地迭代次数减少通信成本的同时,还能够保证模型的预测准确率。
进一步地,基于上述第一实施例,提出本发明模型参数更新方法第二实施例,在本实施例中,当所述第一设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中模型的输出,所述步骤S20包括:
步骤S201,将所述第一设备的训练数据输入所述第一设备中的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
在本实施例中,第一设备是拥有标签数据的数据应用参与方,第二设备是没有标签数据的数据提供参与方时,纵向联邦学习中间结果是在本轮联合参数更新时第二设备中将其训练数据输入到第二模型进行处理得到的输出。在本轮联合参数更新的一轮本地迭代中,第一设备可将其训练数据输入到第一模型进行处理,得到第一模型在本轮本地迭代中的模型输出。
步骤S202,根据所述模型输出和所述纵向联邦中间结果计算得到预测结果,并基于所述预测结果和所述训练数据对应的标签数据计算得到预测损失;
第一设备根据模型输出和纵向联邦中间结果计算得到预测结果,并基于预测结果和训练数据对应的标签数据计算得到预测损失。具体地,根据机器学习模型的类型不同,预测结果的计算方式不同;例如,当纵向联邦学习的机器学习模型是线性回归模型时,第一设备将模型输出和纵向联邦中间结果相加得到预测结果;又如,当纵向联邦学习的机器学习模型是神经网络模型时,第一设备中第一模型包括如图3所示的NetK和Netc两部分,第一设备将训练数据输入到第一模型的NetK部分进行处理,得到模型输出NK,再将NK和纵向联邦中间结果Nj输入到Netc部分进行处理,得到预测结果Yout。
第一设备根据预测结果和训练数据对应的标签数据计算得到预测损失。其中,预测损失可采用常用的损失函数计算方法计算得到,例如交叉熵损失函数,根据所训练的机器学习模型不同,可采用不同的损失函数。
步骤S203,将所述预测损失和所述近端优化损失相加得到总损失,基于所述总损失计算得到所述参数对应的梯度值。
第一设备将预测损失和近端优化损失相加得到总损失。具体地,第一设备可以将两个损失直接相加,也可以将两个损失加权求和,两个损失的权重可以根据需要进行设置。在一实施方式中,第一设备可以将预测损失加上近端优化损失与一调节系数的乘积得到总损失,其中,调节系数可以预先设置,并在各轮本地迭代时进行灵活调整,例如,在一轮联合参数更新中,调节系数可先初始化为0.1,再随着本地迭代的轮次增大而增大,以实现在本地迭代轮次越大时,对参数的变化量约束力度越大。第一设备基于总损失计算得到参数对应的梯度值,具体计算过程在此不做详细赘述。
进一步地,在其他实施方式中,第一设备在计算得到预测损失和近端优化损失后,可计算预测损失相对于第一模型的参数的梯度值,再计算近端优化损失相对于第一模型的参数的梯度值,再将两个梯度值相加,或者加权求和,得到参数对应的梯度值。
如图5所示,为一实施方式中第一设备与第二设备联合进行多轮联合参数更新的交互流程示意图。
在本实施例中,第一设备通过预测损失和近端优化损失来计算第一模型的参数对应的梯度值,再根据梯度值更新参数,实现了往最小化预测损失和近端优化损失的方向更新参数,不仅提高了模型的预测准确率,还约束了参数的变化量,避免参数变化过大而失真,从而在保证通过增加本地迭代次数来减少通信成本的同时,还能够保证模型的预测准确率。
进一步地,基于上述第一和/或第二实施例,提出本发明模型参数更新方法第三实施例,在本实施例中,当所述第二设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中预测损失相对于所述第一设备在本轮联合参数更新时发送的第一模型的输出的梯度值,所述步骤S20包括:
步骤S204,将所述第一设备的训练数据输入所述第一设备的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
在本实施例中,当第一设备是没有标签数据的数据提供参与方,第二设备是拥有标签数据的数据应用参与方时,在本轮联合参数更新时第一设备将训练数据输入到第一模型进行处理得到输出,将该输出作为中间结果发送给第二设备,第二设备计算出预测损失相对于该输出对应的梯度值,作为中间结果发送给第一设备,该中间结果即纵向联邦中间结果。在本轮联合参数更新的一轮本地迭代中,第一设备可将其训练数据输入到第一模型进行处理,得到第一模型在本轮本地迭代中的模型输出。在一实施例中,如图3所示,第一设备将训练数据输入Netj进行处理,得到模型输出Nj。
步骤S205,根据所述模型输出和所述纵向联邦中间结果计算得到所述预测损失相对于所述参数的第一子梯度值;
第一设备根据模型输出和纵向联邦中间结果计算得到预测损失相对于第一模型中参数的梯度值(未示区别以下称为第一子梯度值)。第一设备根据纵向联邦中间结果和模型输出,按照反向传播方法计算得到第一模型中各个参数对应的第一子梯度值。具体可按照如下公式计算出第一子梯度值:
其中,w是第一模型中的参数,Nj是本轮联合参数更新时发送给第二设备的中间结果(也即在本地迭代之前第一模型的模型输出),G(Nj)是第二设备返回预测损失相对于Nj的梯度值,也即纵向联邦中间结果,Nb是本轮本地迭代时第一模型的模型输出。
步骤S206,计算所述近端优化损失相对于所述参数的第二子梯度值,将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值。
第一设备在计算得到近端优化损失后,计算近端优化损失相对于第一模型中参数的梯度值(未示区别以下称为第二子梯度值)。第一设备将参数的第一子梯度值和第二子梯度值相加,即可得到参数对应的梯度值。具体地,当有多个参数时,每个参数分别有对应的第一子梯度值和第二子梯度值,将每个参数各自的第一子梯度值和第二子梯度值相加,得到各个参数分别对应的梯度值。
在本实施例中,第一设备通过预测损失和近端优化损失来计算第一模型的参数对应的梯度值,再根据梯度值更新参数,实现了往最小化预测损失和近端优化损失的方向更新参数,不仅提高了模型的预测准确率,还约束了参数的变化量,避免参数变化过大而失真,从而在保证通过增加本地迭代次数来减少通信成本的同时,还能够保证模型的预测准确率。
进一步地,所述步骤S206中将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值的步骤包括:
步骤S2061,将所述第二子梯度值乘以预设调节系数后加上所述第一子梯度值得到所述参数对应的梯度值。
在一实施方式中,第一设备中可以设置一个调节系数来调节各轮本地迭代时对参数变化量的约束力度。具体地,第一设备可以将第二子梯度值乘以该调节系数后再加上第一子梯度值,得到参数对应的梯度值。第一设备可以根据本地迭代的轮次来调整调节系数,例如,在一轮联合参数更新中,调节系数可先初始化为0.1,再随着本地迭代的轮次增大而增大,以实现在本地迭代轮次越大时,对参数的变化量约束力度越大。
进一步地,基于上述第一、第二和/或第三实施例,提出本发明用户风险预测方法第四实施例,在本实施例中,所述方法应用于参与纵向联邦学习的第一设备,第一设备与参与纵向联邦学习的第二设备通信连接,第一设备和第二设备可以是智能手机、个人计算机和服务器等设备。所述用户风险预测方法包括以下步骤:
步骤A10,基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型,其中,所述近端优化损失表征本端待训练模型的参数在当次本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
第一设备可以是数据应用参与方,也可以是数据提供参与方。第一设备中部署有基于各个用户在第一数据特征下的数据构建的第一数据集和第一模型(以下也称本端待训练模型),第二设备中部署有基于各个用户在第二数据特征下的数据构建的第二数据集和第二模型(以下也称他端待训练模型),两个数据集的用户维度相同;第一数据特征和第二数据特征是与预测用户风险相关的数据特征,且第一数据特征与第二数据特征不相同;第一模型和第二模型是一个完成机器学习模型的两个部分,具体可以根据需要选取常用的机器学习模型来实现,例如线性回归模型或神经网络模型,模型的预测结果设置为能够表征用户风险程度的数据形式,例如风险值;第一设备和第二设备联合采用第一数据集和第二数据集来训练第一模型和第二模型,训练完成后,可采用两个模型联合预测用户的风险。其中,风险可以是用户贷款前的信用风险,用户贷款中的拖欠还款风险等。例如,在一实施方式中,第一设备是部署于银行的设备,第一数据特征是银行业务相关的特征,例如用户历史贷款次数、用户历史违约次数等;第二设备是部署于电商的设备,第二数据特征是电商业务相关的特征,例如用户的历史购买次数、金额等,第一设备和第二设备采用各自的数据集进行纵向联邦学习,训练用于预测贷款前信用风险的模型。
具体地,第一设备基于近端优化损失与第二设备联合进行纵向联邦学习得到本端风险预测模型。具体地,第一设备可以按照上述第一、第二或第三实施例中的模型参数更新方法进行各轮联合参数更新中的各轮本地迭代,以更新第一模型中的参数,在此不进行详细赘述。在进行多轮联合参数更新后,第一设备将最终更新参数后的第一模型作为本端风险预测模型。
步骤A20,采用所述本端风险预测模型预测得到待预测用户的风险值。
第一设备在得到本端风险预测模型后,可以采用本端风险预测模型预测待预测用户的风险值。具体地,第二设备也按照上述实施例中的模型参数更新方法进行各轮联合参数更新中的各轮本地迭代,以更新第二模型中的参数,在进行多轮联合参数更新后,第二设备将最终更新参数后的第二模型作为他端风险预测模型(其中,他端是指第二设备);第一设备可以采用本端风险预测模型,联合第二设备中的他端风险预测模型进行预测得到待预测用户的风险值。其中,风险值可以是一个表示用户的风险程度大小的值。
在一实施方式中,第二设备可以将他端风险预测模型发送给第一设备,第一设备采用将待预测用户在第一数据特征下的用户数据输入到本端风险预测模型得到一个模型输出,再将待预测用户在第二数据特征下的用户数据输入到他端风险预测模型得到一个模型输出,根据两个模型输出得到待预测用户的风险值,例如,将两个模型直接相加。在另一实施方式中,若第一设备是拥有标签数据的数据应用参与方,则第一设备将待预测用户在第一数据特征下的用户数据输入到本端风险预测模型得到一个模型输出;第二设备将待预测用户在第二数据特征下的用户数据输入到他端风险预测模型得到一个模型输出,并发送给第一设备;第一设备根据两个模型输出计算得到待预测用户的风险值,例如第一设备的本端风险预测模型包括如图3所示的NetK和Netc两部分时,第一设备将各个模型输出输入到Netc部分进行处理,得到待预测用户的风险值。在另一实施方式中,若第一设备是没有标签数据的数据提供参与方,第二设备是拥有标签数据的数据应用参与方,则第一设备将待预测用户在第一数据特征下的用户数据输入到本端风险预测模型得到一个模型输出,并将该模型输出发送给第二设备;第二设备将待预测用户在第二数据特征下的用户数据输入到他端风险预测模型得到一个模型输出,第二设备根据两个模型输出计算得到待预测用户的风险值,并将风险值返回给第一设备。
在本实施例中,通过第一设备在与第二设备进行纵向联邦学习的过程中,增加能够表征第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量的近端优化损失,通过近端算函数来约束第一模型的参数在本地迭代中的变化量,避免本地迭代时参数值变化过大导致失真,从而实现了在减少用户风险预测时通信成本的同时,还能给保证用户风险预测的准确度。
进一步地,在一实施方式中,所述步骤A10包括:
步骤A101,接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果;
第一设备在进行一轮联合参数更新时,第一设备接收第二设备发送的本轮联合参数更新的纵向联邦中间结果。具体地,若第一设备是拥有标签数据的数据应用参与方,则第二设备在本轮联合参数更新中将其训练数据输入第二模型进行处理得到模型的输出,将该输出作为中间结果发送给第一设备,该中间结果即纵向联邦中间结果。若第一设备是没有标签数据的数据提供参与方,则第一设备将其训练数据输入到第一模型进行处理得到模型的输出,将该输出作为中间结果发送给第二设备,第二设备计算预测损失相对于该输出的梯度值,将梯度值作为中间结果发送给第一设备,该中间结果即纵向联邦中间结果。
步骤A102,基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新;
第一设备基于近端优化损失和纵向联邦中间结果对本端待训练模型中的参数进行预设轮数的本地迭代更新,具体可以参照上述第一、第二或第三实施例中的模型参数更新方法对本端待训练模型进行各轮本地迭代,在此不进行详细赘述。其中,预设轮数可以是预先根据需要设置的一个数量。
步骤A103,检测更新参数后的本端待训练模型是否满足预设模型条件;
第一设备在进行预设轮数的本地迭代后,检测更新参数后的本端待训练模型是否满足预设模型条件。其中,预设模型条件可以是预先设置的一个条件,例如预测损失收敛,又如联合参数更新的轮次达到一个预定的轮次,或联合参数更新的时长达到一个预定的时长。
步骤A104,若满足,则将更新参数后的本端待训练模型作为所述本端风险预测模型;
若检测到满足预设模型条件,则第一设备可以将更新参数后的本端待训练模型作为本端风险预测模型。对应地,第二设备将更新参数后的他端待训练模型作为他端风险预测模型。
步骤A105,若不满足,则返回执行所述接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果的步骤。
若检测到不满足预设模型条件,则第一设备再返回到上述步骤A101,也即,进行下一轮联合参数更新。
进一步地,在一实施方式中,所述步骤A102包括:
步骤A1021,计算近端优化损失,并基于所述近端优化损失、所述本端待训练模型在本轮本地迭代中的模型输出以及所述纵向联邦中间结果,计算得到所述参数对应的梯度值;
步骤A1022,采用所述梯度值更新所述参数以完成本轮本地迭代;
具体地,第一设备可以按照上述第一实施例中步骤S10和步骤S20的具体实施过程来计算近端优化损失以及各个参数对应的梯度值,并按照上述步骤S30的具体实施过程来根据梯度值更新参数,在本实施例中不作详细赘述。
步骤A1023,检测本地迭代轮数是否达到预设轮数;
步骤A1024,若达到,则执行所述检测更新参数后的本端待训练模型是否满足预设模型条件的步骤;
步骤A1025,若未达到,则返回执行所述计算近端优化损失的步骤,并将所述本地迭代轮数自增1。
第一设备在完成一轮本地迭代后,检测当前本地迭代轮数是否达到了预设轮数;若达到,则第一设备执行步骤A103;若未达到,则将本地迭代轮数自增1,并返回至步骤A1021,也即进行下一轮本地迭代。
此外本发明实施例还提出一种模型参数更新装置,参照图6,所述装置部署于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述装置包括:
第一计算模块10,用于计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
第二计算模块20,用于基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
更新模块30,用于采用所述梯度值更新所述参数,以完成本轮本地迭代。
进一步地,所述第一计算模块10包括:
第一计算单元,用于将所述第一设备中第一模型的参数在本轮本地迭代中的参数向量与在预设历史轮次的本地迭代中的参数向量进行对应元素相减,得到差向量;
第二计算单元,用于计算所述差向量中各元素的平方和,基于所述平方和得到所述近端优化损失。
进一步地,当所述第一设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中模型的输出,所述第二计算模块20包括:
第一处理单元,用于将所述第一设备的训练数据输入所述第一设备中的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
第三计算单元,用于根据所述模型输出和所述纵向联邦中间结果计算得到预测结果,并基于所述预测结果和所述训练数据对应的标签数据计算得到预测损失;
第四计算单元,用于将所述预测损失和所述近端优化损失相加得到总损失,基于所述总损失计算得到所述参数对应的梯度值。
进一步地,当所述第二设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中预测损失相对于所述第一设备在本轮联合参数更新时发送的第一模型的输出的梯度值,
所述第二计算模块20包括:
第二处理单元,用于将所述第一设备的训练数据输入所述第一设备的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
第五计算单元,用于根据所述模型输出和所述纵向联邦中间结果计算得到所述预测损失相对于所述参数的第一子梯度值;
第六计算单元,用于计算所述近端优化损失相对于所述参数的第二子梯度值,将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值。
进一步地,所述第六计算单元还用于:
将所述第二子梯度值乘以预设调节系数后加上所述第一子梯度值得到所述参数对应的梯度值。
此外本发明实施例还提出一种用户风险预测装置,所述装置部署于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述装置包括:
联邦学习模块,用于基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型,其中,所述近端优化损失表征本端待训练模型的参数在当次本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
预测模块,用于采用所述本端风险预测模型预测得到待预测用户的风险值。
进一步地,所述联邦学习模块包括:
接收单元,用于接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果;
本地迭代单元,用于基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新;
检测单元,用于检测更新参数后的本端待训练模型是否满足预设模型条件;
确定单元,用于若满足,则将更新参数后的本端待训练模型作为所述本端风险预测模型;
返回单元,用于若不满足,则返回执行所述接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果的步骤。
进一步地,所述本地迭代单元包括:
计算子单元,用于计算近端优化损失,并基于所述近端优化损失、所述本端待训练模型在本轮本地迭代中的模型输出以及所述纵向联邦中间结果,计算得到所述参数对应的梯度值;
更新子单元,用于采用所述梯度值更新所述参数以完成本轮本地迭代;
检测子单元,用于检测本地迭代轮数是否达到预设轮数;
执行子单元,用于若达到,则执行所述检测更新参数后的本端待训练模型是否满足预设模型条件的步骤;
返回子单元,用于若未达到,则返回执行所述计算近端优化损失的步骤,并将所述本地迭代轮数自增1。
此外,本发明实施例还提出一种计算机可读存储介质,所述存储介质上存储有模型参数更新程序,所述模型参数更新程序被处理器执行时实现如上所述的模型参数更新方法的步骤。本发明还提出一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上所述的模型参数更新方法的步骤。本发明模型参数更新设备、计算机可读存储介质和计算机产品的各实施例,均可参照本发明模型参数更新方法各实施例,此处不再赘述。
此外,本发明实施例还提出一种计算机可读存储介质,所述存储介质上存储有用户风险预测程序,所述用户风险预测程序被处理器执行时实现如上所述的用户风险预测方法的步骤。本发明还提出一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上所述的用户风险预测方法的步骤。本发明用户风险预测设备、计算机可读存储介质和计算机产品的各实施例,均可参照本发明用户风险预测方法各实施例,此处不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各实施例所述的方法。
以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (12)
1.一种模型参数更新方法,其特征在于,所述方法应用于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述方法包括以下步骤:
计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
采用所述梯度值更新所述参数,以完成本轮本地迭代。
2.如权利要求1所述的模型参数更新方法,其特征在于,所述计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量的步骤包括:
将所述第一设备中第一模型的参数在本轮本地迭代中的参数向量与在预设历史轮次的本地迭代中的参数向量进行对应元素相减,得到差向量;
计算所述差向量中各元素的平方和,基于所述平方和得到所述近端优化损失。
3.如权利要求1至2任一项所述的模型参数更新方法,其特征在于,当所述第一设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中模型的输出,
所述基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值的步骤包括:
将所述第一设备的训练数据输入所述第一设备中的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
根据所述模型输出和所述纵向联邦中间结果计算得到预测结果,并基于所述预测结果和所述训练数据对应的标签数据计算得到预测损失;
将所述预测损失和所述近端优化损失相加得到总损失,基于所述总损失计算得到所述参数对应的梯度值。
4.如权利要求1至2任一项所述的模型参数更新方法,其特征在于,当所述第二设备为拥有标签数据的参与方时,所述纵向联邦中间结果为所述第二设备中预测损失相对于所述第一设备在本轮联合参数更新时发送的第一模型的输出的梯度值,
所述基于所述近端优化损失、所述模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值的步骤包括:
将所述第一设备的训练数据输入所述第一设备的第一模型进行处理,得到所述第一模型在本轮本地迭代中的模型输出;
根据所述模型输出和所述纵向联邦中间结果计算得到所述预测损失相对于所述参数的第一子梯度值;
计算所述近端优化损失相对于所述参数的第二子梯度值,将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值。
5.如权利要求4所述的模型参数更新方法,其特征在于,所述将所述第一子梯度值和所述第二子梯度值相加得到所述参数对应的梯度值的步骤包括:
将所述第二子梯度值乘以预设调节系数后加上所述第一子梯度值得到所述参数对应的梯度值。
6.一种用户风险预测方法,其特征在于,所述方法应用于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述方法包括以下步骤:
基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型,其中,所述近端优化损失表征本端待训练模型的参数在当次本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
采用所述本端风险预测模型预测得到待预测用户的风险值。
7.如权利要求6所述的用户风险预测方法,其特征在于,所述基于近端优化损失与所述第二设备联合进行纵向联邦学习得到本端风险预测模型的步骤包括:
接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果;
基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新;
检测更新参数后的本端待训练模型是否满足预设模型条件;
若满足,则将更新参数后的本端待训练模型作为所述本端风险预测模型;
若不满足,则返回执行所述接收所述第二设备发送的本轮联合参数更新的纵向联邦中间结果的步骤。
8.如权利要求6至7任一项所述的用户风险预测方法,其特征在于,所述基于近端优化损失和所述纵向联邦中间结果对所述本端待训练模型中的参数进行预设轮数的本地迭代更新的步骤包括:
计算近端优化损失,并基于所述近端优化损失、所述本端待训练模型在本轮本地迭代中的模型输出以及所述纵向联邦中间结果,计算得到所述参数对应的梯度值;
采用所述梯度值更新所述参数以完成本轮本地迭代;
检测本地迭代轮数是否达到预设轮数;
若达到,则执行所述检测更新参数后的本端待训练模型是否满足预设模型条件的步骤;
若未达到,则返回执行所述计算近端优化损失的步骤,并将所述本地迭代轮数自增1。
9.一种模型参数更新装置,其特征在于,所述装置部署于参与纵向联邦学习的第一设备,所述第一设备与参与纵向联邦学习的第二设备通信连接,所述装置包括:
第一计算模块,用于计算近端优化损失,其中,所述近端优化损失表征所述第一设备中第一模型的参数在本轮本地迭代中的参数值相比于在预设历史轮次的本地迭代中参数值的变化量;
第二计算模块,用于基于所述近端优化损失、所述第一模型在本轮本地迭代中的模型输出以及从所述第二设备接收到的纵向联邦中间结果,计算得到所述参数对应的梯度值;
更新模块,用于采用所述梯度值更新所述参数,以完成本轮本地迭代。
10.一种模型参数更新设备,其特征在于,所述模型参数更新设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的模型参数更新程序,所述模型参数更新程序被所述处理器执行时实现如权利要求1至5中任一项所述的模型参数更新方法的步骤。
11.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有模型参数更新程序,所述模型参数更新程序被处理器执行时实现如权利要求1至5中任一项所述的模型参数更新方法的步骤。
12.一种计算机程序产品,包括计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至5中任一项所述的模型参数更新方法的步骤。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110287041.3A CN113011603A (zh) | 2021-03-17 | 2021-03-17 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
PCT/CN2021/094936 WO2022193432A1 (zh) | 2021-03-17 | 2021-05-20 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110287041.3A CN113011603A (zh) | 2021-03-17 | 2021-03-17 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113011603A true CN113011603A (zh) | 2021-06-22 |
Family
ID=76409316
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110287041.3A Pending CN113011603A (zh) | 2021-03-17 | 2021-03-17 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN113011603A (zh) |
WO (1) | WO2022193432A1 (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114330759A (zh) * | 2022-03-08 | 2022-04-12 | 富算科技(上海)有限公司 | 一种纵向联邦学习模型的训练方法及系统 |
CN116128072A (zh) * | 2023-01-20 | 2023-05-16 | 支付宝(杭州)信息技术有限公司 | 一种风险控制模型的训练方法、装置、设备及存储介质 |
WO2024036526A1 (zh) * | 2022-08-17 | 2024-02-22 | 华为技术有限公司 | 一种模型调度方法和装置 |
Families Citing this family (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116186782B (zh) * | 2023-04-17 | 2023-07-14 | 北京数牍科技有限公司 | 联邦图计算方法、装置及电子设备 |
CN116205313B (zh) * | 2023-04-27 | 2023-08-11 | 数字浙江技术运营有限公司 | 联邦学习参与方的选择方法、装置及电子设备 |
CN116610958A (zh) * | 2023-06-20 | 2023-08-18 | 河海大学 | 面向无人机群水库水质检测的分布式模型训练方法及系统 |
CN117151208B (zh) * | 2023-08-07 | 2024-03-22 | 大连理工大学 | 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 |
CN117575291B (zh) * | 2024-01-15 | 2024-05-10 | 湖南科技大学 | 基于边缘参数熵的联邦学习的数据协同管理方法 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109754105B (zh) * | 2017-11-07 | 2024-01-05 | 华为技术有限公司 | 一种预测方法及终端、服务器 |
JP7276436B2 (ja) * | 2019-05-21 | 2023-05-18 | 日本電気株式会社 | 学習装置、学習方法、コンピュータプログラム及び記録媒体 |
CN111210003B (zh) * | 2019-12-30 | 2021-03-19 | 深圳前海微众银行股份有限公司 | 纵向联邦学习系统优化方法、装置、设备及可读存储介质 |
CN111242316B (zh) * | 2020-01-09 | 2024-05-28 | 深圳前海微众银行股份有限公司 | 纵向联邦学习模型训练优化方法、装置、设备及介质 |
CN111860864A (zh) * | 2020-07-23 | 2020-10-30 | 深圳前海微众银行股份有限公司 | 纵向联邦建模优化方法、设备及可读存储介质 |
-
2021
- 2021-03-17 CN CN202110287041.3A patent/CN113011603A/zh active Pending
- 2021-05-20 WO PCT/CN2021/094936 patent/WO2022193432A1/zh active Application Filing
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114330759A (zh) * | 2022-03-08 | 2022-04-12 | 富算科技(上海)有限公司 | 一种纵向联邦学习模型的训练方法及系统 |
CN114330759B (zh) * | 2022-03-08 | 2022-08-02 | 富算科技(上海)有限公司 | 一种纵向联邦学习模型的训练方法及系统 |
WO2024036526A1 (zh) * | 2022-08-17 | 2024-02-22 | 华为技术有限公司 | 一种模型调度方法和装置 |
CN116128072A (zh) * | 2023-01-20 | 2023-05-16 | 支付宝(杭州)信息技术有限公司 | 一种风险控制模型的训练方法、装置、设备及存储介质 |
CN116128072B (zh) * | 2023-01-20 | 2023-08-25 | 支付宝(杭州)信息技术有限公司 | 一种风险控制模型的训练方法、装置、设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
WO2022193432A1 (zh) | 2022-09-22 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113011603A (zh) | 模型参数更新方法、装置、设备、存储介质及程序产品 | |
CN111860864A (zh) | 纵向联邦建模优化方法、设备及可读存储介质 | |
CN111222628B (zh) | 循环神经网络训练优化方法、设备、系统及可读存储介质 | |
CN112422644B (zh) | 计算任务卸载方法及系统、电子设备和存储介质 | |
CN113408743A (zh) | 联邦模型的生成方法、装置、电子设备和存储介质 | |
CN111860868A (zh) | 训练样本构建方法、装置、设备及计算机可读存储介质 | |
US11334758B2 (en) | Method and apparatus of data processing using multiple types of non-linear combination processing | |
CN112052960A (zh) | 纵向联邦建模方法、装置、设备及计算机可读存储介质 | |
CN111797999A (zh) | 纵向联邦建模优化方法、装置、设备及可读存储介质 | |
CN112948885B (zh) | 实现隐私保护的多方协同更新模型的方法、装置及系统 | |
CN113392971A (zh) | 策略网络训练方法、装置、设备及可读存储介质 | |
CN109635422B (zh) | 联合建模方法、装置、设备以及计算机可读存储介质 | |
CN114519435A (zh) | 模型参数更新方法、模型参数更新装置和电子设备 | |
CN112686370A (zh) | 网络结构搜索方法、装置、设备、存储介质及程序产品 | |
CN112861165A (zh) | 模型参数更新方法、装置、设备、存储介质及程序产品 | |
CN113592593B (zh) | 序列推荐模型的训练及应用方法、装置、设备及存储介质 | |
CN111475392A (zh) | 生成预测信息的方法、装置、电子设备和计算机可读介质 | |
CN114760308A (zh) | 边缘计算卸载方法及装置 | |
CN115034379A (zh) | 一种因果关系确定方法及相关设备 | |
CN111510473B (zh) | 访问请求处理方法、装置、电子设备和计算机可读介质 | |
CN113190872A (zh) | 数据保护方法、网络结构训练方法、装置、介质及设备 | |
CN112926090A (zh) | 基于差分隐私的业务分析方法及装置 | |
CN116306981A (zh) | 策略确定方法、装置、介质及电子设备 | |
CN110087230B (zh) | 数据处理方法、装置、存储介质及电子设备 | |
CN113822455A (zh) | 一种时间预测方法、装置、服务器及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |