发明内容
有鉴于此,本公开的目的在于提出一种联邦学习模型训练方法、电子设备及存储介质。
基于上述目的,本公开提供了一种联邦学习模型训练方法,包括:
任一参与方设备基于本方模型参数和特征信息与其他参与方设备进行联合加密训练,获得本方的梯度信息;
任一参与方设备基于模型参数和梯度信息获取模型参数变化量和梯度信息变化量,并基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得本方的梯度搜索方向作为拟牛顿条件;
目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息;其中,所述目标参与方设备为任一参与方设备中具有标签信息的参与方设备,所述模型损失函数为凸函数;
任一参与方设备基于所述梯度搜索方向、所述步长信息对本方的模型参数进行更新,直至所述联邦学习模型收敛。
所述任一参与方设备基于所述模型参数变化量和所述梯度信息变化量,采用双向循环递归方法与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向作为拟牛顿条件,包括:
任一参与方设备基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得中间变化量;所述中间变化量用于表征所述梯度信息的大小;
任一参与方设备基于所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。
可选的,所述任一参与方设备基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得中间变化量,还包括:
任一参与方设备基于本方的所述模型参数变化量和所述梯度信息变化量计算本方第一中间值信息,与其他参与方设备交换第一中间值信息并基于各参与方设备的第一中间值信息计算第一全局中间值,以根据所述第一全局中间值计算所述中间变化量。
可选的,所述第一中间值信息基于所述梯度信息变化量的转置矩阵与所述模型参数变化量的乘积获得。
可选的,所述任一参与方设备基于所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向,还包括:
任一参与方设备基于本方的所述中间变化量计算本方的第二中间值信息;
任一参与方设备基于本方的所述第二中间值信息,与其他参与方设备交换第二中间值信息并基于各参与方设备的第二中间值信息计算第二全局中间值,以根据所述第二全局中间值计算所述梯度搜索方向。
可选的,所述任一参与方设备基于本方的所述中间变化量计算本方的第二中间值信息,包括:
任一参与方设备基于本方的所述模型参数变化量的转置矩阵、所述模型参数变化量获得第一标量信息,基于本方的所述梯度信息变化量的转置矩阵、所述梯度信息变化量获得第二标量信息;
任一参与方设备与其他参与方设备进行交互以获得其他参与方设备的第三标量信息和第四标量信息;所述第三标量信息基于其他参与方设备的模型参数变化量的转置矩阵、模型参数变化量获得,所述第四标量信息基于其他参与方设备的梯度信息变化量的转置矩阵、梯度信息变化量获得;
任一参与方设备基于所述第一标量信息、所述第二标量信息、所述第三标量信息、所述第四标量信息、所述中间变化量计算本方第二中间值信息。
可选的,所述第一全局中间值为各参与方设备的第一中间值信息之和,所述第二全局中间值为各参与方设备的第二中间值信息之和。
可选的,所述目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息,包括:
目标参与方设备获取样本标签信息,并基于本方模型参数、特征信息以及其他参与方设备的第一数据信息获得样本标签预测信息;其中,所述第一数据信息基于其他参与方设备的模型参数、特征信息获得;
目标参与方设备基于所述样本标签预测信息及所述样本标签信息计算所述模型损失函数;
目标参与方设备判断所述模型损失函数是否满足预设条件,若是,则将当前步长信息作为最终的步长信息;否则,减少所述步长信息的值并重新计算所述模型损失函数。
可选的,所述基于本方模型参数、特征信息以及其他参与方设备的数据信息获得样本标签预测信息,包括:
目标参与方设备基于本方模型参数、特征信息计算模型参数的转置矩阵与特征信息的乘积获得第二数据信息;
目标参与方设备基于所述第二数据信息与其他参与方设备进行交互,获得其他参与方设备的第一数据信息;
目标参与方设备基于第一数据信息、第二数据信息以及预设模型函数获得所述样本标签预测信息。
本公开还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述任意一项所述的方法。
本公开还提供了一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行上述任一所述的方法。
从上面所述可以看出,本公开提供的联邦学习模型训练方法、电子设备及存储介质,各参与方设备通过与其他参与方设备进行联合加密训练获得本方的梯度信息之后,基于模型参数变化量和梯度信息变化量与其他参与方设备进行联合训练从而获得各自的梯度搜索方向;之后,目标参与方设备基于梯度搜索方向以及模型损失函数计算步长信息;最后,各参与方设备基于梯度搜索方向、步长信息对本方的模型参数进行更新,从而无需计算Hessian矩阵的逆矩阵,相比于随机梯度下降方法、牛顿法和拟牛顿法其计算量小、通信量少,且可以保证快速收敛。
具体实施方式
为使本公开的目的、技术方案和优点更加清楚明白,以下结合具体实施例,并参照附图,对本公开进一步详细说明。
需要说明的是,除非另外定义,本公开实施例使用的技术术语或者科学术语应当为本公开所属领域内具有一般技能的人士所理解的通常意义。本公开实施例中使用的“第一”、“第二”以及类似的词语并不表示任何顺序、数量或者重要性,而只是用来区分不同的组成部分。“包括”或者“包含”等类似的词语意指出现该词前面的元件或者物件涵盖出现在该词后面列举的元件或者物件及其等同,而不排除其他元件或者物件。“连接”或者“相连”等类似的词语并非限定于物理的或者机械的连接,而是可以包括电性的连接,不管是直接的还是间接的。“上”、“下”、“左”、“右”等仅用于表示相对位置关系,当被描述对象的绝对位置改变后,则该相对位置关系也可能相应地改变。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习、自动驾驶、智慧交通等几大方向。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
随着机器学习的快速发展,机器学习可应用于各个领域,如数据挖掘、计算机视觉、自然语言处理、生物特征识别、医学诊断、检测信用卡欺诈、证券市场分析和DNA序列测序等。对比传统机器学习方法,深度神经网络是一种较新的技术,用多层的网络结构建立机器学习模型,从数据中自动学习到表示特征。由于易于使用、实践效果好,在图像识别、语音识别、自然语言处理、搜索推荐等领域得到广泛应用。
联邦学习(Federated Learning),又可以称为联邦机器学习、联合学习、联盟学习等。联邦机器学习是一个机器学习框架,各个参与方联合建立机器学习模型,且在训练中只交换中间数据,而不直接交换各参与方的业务数据。
具体地,假设企业A、企业B各自建立一个任务模型,单个任务可以是分类或预测,而这些任务也已经在获得数据时由各自用户的认可。然而,由于数据不完整,例如企业A缺少标签数据、企业B缺少特征数据,或者数据不充分,样本量不足以建立好的模型,那么在各端的模型有可能无法建立或效果并不理想。联邦学习要解决的问题是如何在A和B各端建立高质量的机器学习模型,该模型的训练兼用A和B等各个企业的数据,并且各个企业的自有数据不被其他方知晓,即在不交换本方数据的情况下,建立一个共有模型。这个共有模型就好像各方把数据聚合在一起建立的最优模型一样。这样,建好的模型在各方的区域仅为自有的目标服务。
联邦学习的实施架构中包括至少两个参与方设备,各个参与方设备分别可以包括不同的业务数据,还可以通过设备、计算机、服务器等参与模型的联合训练;其中,各个参与方设备可以包括一台服务器、多台服务器、云计算平台和虚拟化中心中的至少一种。这里的业务数据例如可以是字符、图片、语音、动画、视频等各种数据。通常,各个参与方设备所包含的业务数据具有相关性,各个训练成员对应的业务方也可以具有相关性。单个参与方设备可以持有一个业务的业务数据,也可以持有多个业务方的业务数据。
在该实施架构下,可以由两个或两个以上的参与方设备共同训练模型。这里的模型可以用于处理业务数据,得到相应的业务处理结果,因此,也可以称之为业务模型。具体处理什么样的业务数据,得到什么样的业务处理结果,根据实际需求而定。例如,业务数据可以是用户金融相关的数据,得到的业务处理结果为用户的金融信用评估结果,再例如,业务数据可以是客服数据,得到的业务处理结果为客服答案的推荐结果,等等。业务数据的形式也可以是文字、图片、动画、音频、视频等各种形式的数据。各个参与方设备分别可以利用训练好的模型对本地业务数据进行本地业务处理。
可以理解,联邦学习可以分为横向联邦学习(特征对齐)、纵向联邦学习(样本对齐)与联邦迁移学习。本说明书提供的实施架构基于纵向联邦学习提出,即,各个参与方设备之间样本主体重叠,从而可以分别提供样本的部分特征的联邦学习情形。样本主体即待处理的业务数据对应的主体,例如金融风险性评估的业务主体为用户或者企业等。
在纵向联邦学习的二分类场景中,通常采用随机梯度下降(SGD)方法或者牛顿法及拟牛顿法来实现模型的优化。其中,随机梯度下降(SGD)方法的核心思想是利用损失函数对模型参数的一阶梯度来迭代优化模型,但是现有的一阶优化器只利用到了损失函数对模型参数的一阶梯度,收敛速度会比较慢;牛顿法是以二阶导数海森(Hessian)矩阵的逆矩阵乘以一阶梯度来引导参数更新,而这种方法的计算复杂度较高;拟牛顿方法即是将牛顿法中的二阶导数Hessian矩阵的逆用一个n阶矩阵来代替,但是这种方式的算法收敛速度仍然较慢。
有鉴于此,本公开实施例提供一种联邦学习模型训练方法,该方法可以提高纵向联邦学习中模型的收敛速度。如图1所示,所述联邦学习模型训练方法,包括:
步骤S101,任一参与方设备基于本方模型参数和特征信息与其他参与方设备进行联合加密训练,获得本方的梯度信息。
在本实施例中,至少两个参与方设备共同训练联邦学习模型,且各个参与方设备均可基于本参与方设备上的业务数据获得特征信息。在联邦学习模型的训练过程中,各个参与方设备基于加密后的模型参数、特征信息等信息与其他参与方设备进行交互,从而使得各个参与方设备均获得其各自的梯度信息。
步骤S103,任一参与方设备基于模型参数和梯度信息获取模型参数变化量和梯度信息变化量,并基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得本方的梯度搜索方向作为拟牛顿条件。
在本实施例中,在本实施例中,任一参与方设备基于模型参数和梯度信息,通过预设轮数的交互计算即可获得各个参与方设备的梯度搜索方向,各个参与方设备所获得的梯度搜索方向相当于牛顿法中的/>,因此无需直接计算海森矩阵H或者海森矩阵的逆矩阵/>即可,减小了数据的计算量和交互量。
步骤S105,目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息;其中,所述目标参与方设备为任一参与方设备中具有标签信息的参与方设备,所述模型损失函数为凸函数。
本实施例中,由于模型损失函数为凸函数,因此基于该模型损失函数的凸性,通过对其局部极值点的计算即可获得其全局极值点。基于步骤S103中所计算出的各个参与方设备的梯度搜索方向,选择一个步长信息对模型参数进行预更新,直至模型损失函数满足搜索停止条件,则基于该梯度搜索方向、步长信息对模型参数进行更新。
步骤S107,任一参与方设备基于所述梯度搜索方向、所述步长信息对本方的模型参数进行更新,直至所述联邦学习模型收敛。
可选的,在上述实施例中,任一参与方设备为参与联邦学习模型训练中的全部参与方设备中的任意一个,不区分该参与方设备是否具有标签信息。即本实施例中步骤S101、S103以及S107为参与联邦学习模型训练中的全部参与方设备均可执行的步骤。目标参与方设备为参与联邦学习模型训练中的全部参与方设备中具有标签信息的参与方设备,该目标参与方设备不仅执行步骤S101、S103以及S107的方法,也执行步骤S105中的方法。
在本实施例中,各参与方设备通过与其他参与方设备进行联合加密训练获得本方的梯度信息之后,基于模型参数变化量和梯度信息变化量与其他参与方设备进行联合训练从而获得各自的梯度搜索方向作为拟牛顿条件;之后,目标参与方设备基于梯度搜索方向以及模型损失函数计算步长信息;最后,各参与方设备基于梯度搜索方向、步长信息对本方的模型参数进行更新,从而无需计算Hessian矩阵的逆矩阵,相比于随机梯度下降方法、牛顿法和拟牛顿法其计算量小、通信量少,且可以保证快速收敛。
如图2所示,上述实施例所述方法应用于目标参与方设备Guest和除了目标参与方设备以外的其他参与方设备Host之间。其中,所述目标参与方设备Guest存储多个样本的第一特征信息和样本标签信息,所述其他参与方设备Host存储多个样本的第二特征信息。其他参与方设备可以仅包括一个参与方设备,也可以包括多个参与方设备,本实施例中以其他参与方设备仅包括一个参与方设备为例,详细说明基于标参与方设备Guest和其他参与方设备Host的联邦学习模型训练方法。
如图3所示,在一个具体的实施例中,基于双方共有信息(例如id信息)实现目标参与方设备Guest和其他参与方设备Host的数据对齐,对齐后的目标参与方设备Guest和其他参与方设备Host均包括id信息分别为1、2、3的多个样本。其中,其他参与方设备Host包括特征1、特征2以及特征3等多个第二特征信息;目标参与方设备Guest包括特征4(点击)、特征5、特征6等多个第一特征信息以及样本标签信息(购买)。
为了便于本公开实施例的后续表述,另目标参与方设备Guest和其他参与方设备Host的样本的数量为n。目标参与方设备Guest中每一条第一特征信息记为,目标参与方设备Guest中n个样本全部的第一特征信息列记为/>,每一个样本的样本标签为/>,n个样本全部的样本标签信息列为/>;其他参与方设备Host中每一条第二特征信息记为/>,其他参与方设备Host中n个样本全部的第二特征信息列为/>。其中,i表示n个样本中的第i个。
步骤S101,任一参与方设备基于本方模型参数和特征信息与其他参与方设备进行联合加密训练,获得本方的梯度信息。
在本实施例中,目标参与方设备Guest包括构建在目标参与方设备Guest本地的第一本地模型,第一本地模型包括第一模型参数;相应的,其他参与方设备Host包括构建在其他参与方设备Host本地的第二本地模型,第二本地模型包括第二模型参数/>。
在一些实施例中,在步骤S101中,采用同态加密算法或半同态加密算法对联合加密训练过程中的交互数据进行加密,例如可采用Paillier算法进行加密从而保证目标参与方设备Guest和其他参与方设备Host在联合训练的过程中不会泄露。如图4所示,步骤S101具体包括以下步骤:
步骤S201,其他参与方设备获取第一数据信息并发送至目标参与方设备,所述第一数据信息基于第二模型参数与第二特征信息获得。
在本步骤中,其他参与方设备Host获取其他参与方设备本地的第二本地模型的第二模型参数,并计算第二模型参数/>与第二特征信息的内积,从而获得第一数据信息,并将第一数据信息/>发送至目标参与方设备Guest。
可选,在本实施例中,第一数据信息包括第二模型参数/>转置矩阵/>与每一条第二特征信息/>的内积,因此第一数据信息包括与n个样本对应的n条信息。
可选的,在步骤S201中,其他参与方设备Host还可以计算第一正则项并发送至目标参与方设备Guest。其中,第一正则项为L2正则项,且第一正则项为,/>表示正则系数。
可选的,当处于第一次更新周期内时,第二模型参数为初始化后的模型参数初始值;当处于中间的更新周期内时,第二模型参数/>为第二本地模型在上一更新周期内更新后的模型参数。
步骤S203,目标参与方设备获取第二数据信息,所述第二数据信息基于第一模型参数与第一特征信息获得。
在本步骤中,目标参与方设备Guest获取第一本地模型的第一模型参数,并计算第一模型参数/>与第一特征信息的内积,从而获得第二数据信息/>。具体的,在本实施例中,第二数据信息/>包括第一模型参数/>转置矩阵/>与每一条第一特征信息/>的内积。
可选的,在本实施例中,目标参与方设备Guest还计算第二正则项。其中,第二正则项也为L2正则项,且第二正则项为,/>表示正则系数。
可选的,当处于第一次更新周期内时,第一模型参数为初始化后的模型参数初始值;当处于中间的更新周期内时,第一模型参数/>为第一本地模型在上一更新周期内更新后的模型参数。
在步骤S201与步骤S203中,由于在纵向联邦LR模型中,第一模型参数、第二模型参数/>一维向量,因此基于/>的第一数据信息以及基于/>获得的第二数据信息为矩阵相乘后的结果,当第一数据信息和第二数据信息被发送到对方时,对方无法恢复原本的数据信息,从而不会在步骤S201与步骤S203中数据传输过程中泄露明文信息,保证了双方数据的安全。
步骤S205,目标参与方设备基于所述第一数据信息、所述第二数据信息获得样本标签预测信息,对所述样本标签预测信息与所述样本标签信息的差值加密获得第一加密信息,将所述第一加密信息发送至所述其他参与方设备。
在本步骤中,目标参与方设备Guest基于所述第一数据信息、所述第二数据信息获得每一条样本的样本标签预测信息。其中,基于样本标签预测信息/>可判断样本的二分类的概率,从而可以解决纵向联邦模型中二分类的问题。可选的,在一些实施例中,,/>函数定义为/>。
之后,基于每一条样本的样本标签预测信息样本标签信息/>计算每一条样本的所述样本标签预测信息与所述样本标签信息的差值/>,并进行加密获得第一加密信息/>,其中,/>。由于采用了加密算法,加密后的信息在发送至其他参与方设备Host后不会泄露原始的样本标签信息,保证了数据的安全性。
可选的,本步骤中所采用的加密算法可以为半同态加密算法Paillier,或者也可采用其他可选的半同态加密算法或者同态加密算法,本实施例对此不作具体限定。
最后,目标参与方设备Guest将所述第一加密信息发送至所述其他参与方设备Host。
步骤S207,其他参与方设备基于所述第一加密信息、所述第二特征信息以及随机数获取第二加密信息并发送至目标参与方设备。
在本实施例中,其他参与方设备Host基于所述第一加密信息、所述第二特征信息以及随机数的乘积之和获得所述第二加密信息。其中,/>第i个样本的样本标签预测信息,/>第i个样本的样本标签,表示/>第i个样本的第二特征信息,/>第i个样本的随机数。通过随机数的增加,当其他参与方设备Host将第二加密信息/>发送至目标参与方设备Guest时,目标参与方设备Guest无法还原出/>的明文信息,也无法获得其他参与方设备的第二梯度信息,从而避免了数据的泄露。
步骤S209,目标参与方设备对所述第二加密信息进行解密获得第三解密信息,并将所述第三解密信息发送至所述其他参与方设备。其中,第三解密信息基于每一个样本的样本标签预测信息与样本标签信息的差值、第二特征信息以及随机数的积的累加之和获得。
在本步骤中,采用与S205中的加密算法对应的解密算法,目标参与方设备Guest对第二加密信息进行解密,获得第三解密信息/>。之后,目标参与方设备Guest将所述第三解密信息/>发送至所述其他参与方设备Host。
步骤S211,其他参与方设备接收第三解密信息,基于所述随机数获得第四解密信息,并基于所述第四解密信息获得第二梯度信息。
其他参与方设备Host接收第三解密信息后,可去掉随机数/>获得第四解密信息/>。由于第四解密信息/>是累加值,因此即使其他参与方设备Host已知/>也无法解析出每一条/>,从而避免了数据的泄露。
之后,其他参与方设备Host可以基于第四解密信息计算本方的第二梯度信息/>。
步骤S213,目标参与方设备根据所述样本标签预测信息与所述样本标签信息的差值以及第一特征信息计算第五明文信息,基于所述第五明文信息获得所述第一梯度信息。
在本步骤中,目标参与方设备Guest基于每一条样本的所述样本标签预测信息与所述样本标签信息的差值以及每一条样本的第一特征信息/>的乘积之和获得第五明文信息/>,并基于第五明文信息/>计算第一梯度信息。
在上述实施例中,步骤S205中还包括:目标参与方设备基于所述样本标签预测信息、所述样本标签信息计算损失函数Loss。可选的,损失函数Loss中还可以包括第一正则项和第二正则项,包括:
。
步骤S103,任一参与方设备基于模型参数和梯度信息获取模型参数变化量和梯度信息变化量,并基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得本方的梯度搜索方向作为拟牛顿条件。
可选的,在本实施例中,任一参与方设备基于所述模型参数变化量和所述梯度信息变化量,采用例如双向循环递归方法与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。即在本实施例中,目标参与方设备Guest获得第一梯度信息、其他参与方设备Host获得第二梯度信息之后,计算各自的模型参数变化量和梯度信息变化量,并基于双向循环递归方法进行预设轮数的交互计算,从而使得目标参与方设备Guest获得第一梯度搜索方向、其他参与方设备Host获得第二梯度搜索方向。同时,由于在本实施例中,目标参与方设备Guest与其他参与方设备Host所计算、发送以及接收的数据均是基于所述模型参数变化量、所述模型参数变化量的转置矩阵、所述梯度信息变化量、所述梯度信息变化量的转置矩阵中至少两个的向量乘积或标量乘积所获得的,而不涉及大矩阵的运算,因此整个过程中计算量和通信量都很小,从而可以保证模型的快速收敛。
在本实施例中,如图5所示,步骤S103中具体包括:
步骤S301,目标参与方设备Guest获取第一模型参数变化量和第一梯度信息变化量,其他参与方设备Host获取第二模型参数变化量和第二梯度信息变化量。
在本实施例中,为了便于表示,令表示梯度信息,其中,/>表示第一梯度信息,/>表示第二梯度信息。令t表示梯度信息/>的变化量/>,则/>表示第一梯度信息变化量,/>表示表示第二梯度信息变化量。s表示模型参数变化量/>,则/>表示第一模型参数变化量,/>表示第二模型参数变化量。
步骤S303,任一参与方设备基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得中间变化量;所述中间变化量用于表征所述梯度信息的大小。
可选的,在本实施例中,可采用双向循环算法进行梯度搜索方向的计算。其中,包括:在后向循环过程中,任一参与方设备基于基于所述第一中间信息与其他参与方设备进行预设轮数的交互计算,获得中间变化量。
其中,预设轮数为3-5中的一个,且后向循环与前向循环的轮数相同。
在本实施例中,具有第一梯度信息变化量和第一模型参数变化量/>的目标参与方设备Guest与具有第二梯度信息变化量/>和第二模型参数变化量/>的其他参与方设备Host,进行3-5轮的交互计算后,目标参与方设备Guest获得本方的中间变化量/>,其他参与方设备Host获得本方的中间变化量/>。
同时,在后向循环过程中,任一参与方设备基于本方的第一中间值信息,与其他参与方设备交换第一中间值信息并基于各参与方设备的第一中间值信息计算第一全局中间值,以根据所述第一全局中间值计算所述中间变化量。
在本实施例中,后向循环过程中的第一中间值信息包括、/>和/>、/>,目标参与方设备Guest与其他参与方设备Host分别基于本方的模型参数变化量、梯度信息变化量计算本方的第一中间值信息之后,需交换各参与方设备的第一中间值信息,从而获得第一全局中间值/>和/>。可选的,第一全局中间值可以为各参与方设备的第一中间值信息之和,或者也可以根据需求进行设置,本说明书对此不作限制。
具体的,目标参与方设备Guest与其他参与方设备Host分别基于本方梯度信息变化量的转置矩阵、模型参数变化量的乘积获得第一中间值信息、/>,交换各自的第一中间值信息/>、/>后获得第一全局中间值/>;再结合该第一全局中间值/>、模型参数变化量的转置矩阵以及梯度信息计算第一中间值信息/>、/>,再交换第一中间值信息/>、/>后计算第一全局中间值/>,最终基于/>计算本方的中间变化量。
下面结合具体实施例进一步详述本实施例中后向循环的步骤,包括:
步骤S401,目标参与方设备Guest初始化,其他参与方设备Host初始化。
步骤S403,对以下步骤迭代L轮,从/>到/>,/>从/>到/>。其中L表示预设轮数,且L=3~5;/>表示当前的循环轮数。
1). 其他参与方设备Host方计算中间过程变量;
2). 目标参与方设备Guest方计算中间过程变量;
3). 目标参与方设备Guest和其他参与方设备Host交换值后计算/>;
4). 其他参与方设备Host方计算中间过程变量;
5). 目标参与方设备Guest方计算中间过程变量;
6). 目标参与方设备Guest和其他参与方设备Host交换值后计算/>;
7). 其他参与方设备Host方计算中间变化量;
8). 目标参与方设备Guest方计算中间变化量。
在步骤S403中各步骤的各中间过程变量的计算与交换过程中,都是向量乘法或标量乘法的计算与交换,不涉及大矩阵的计算,因此在训练过程中的计算量和通信量都较少,不仅可以保证模型的快速收敛,还可以提高目标参与方设备与其他参与方设备的硬件处理速率。
步骤S305,任一参与方设备基于所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。
可选的,步骤S305进一步包括:任一参与方设备基于本方的所述中间变化量计算本方的第二中间值信息;任一参与方设备基于本方的所述第二中间值信息,与其他参与方设备交换第二中间值信息并基于各参与方设备的第二中间值信息计算第二全局中间值,以根据所述第二全局中间值计算所述梯度搜索方向。
在本实施例中,可采用双向循环算法进行梯度搜索方向的计算。其中,包括:在前向循环过程中,任一参与方设备基于所述模型参数变化量、所述模型参数变化量的转置矩阵、所述梯度信息变化量、所述梯度信息变化量的转置矩阵中至少两个的向量乘积或标量乘积获得第二中间值信息,并基于所述第二中间值信息、所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。
在本实施例中,具有中间变化量的目标参与方设备Guest与具有中间变化量/>的其他参与方设备Host,进行3-5轮的交互计算后,目标参与方设备Guest获得本方的第一梯度搜索方向/>,其他参与方设备Host获得本方的第二梯度搜索方向/>。
下面结合具体实施例进一步详述本实施例中前向循环的步骤,包括:
步骤S501,任一参与方设备基于本方的所述模型参数变化量的转置矩阵、所述模型参数变化量获得第一标量信息,基于所述梯度信息变化量的转置矩阵、所述梯度信息变化量获得第二标量信息。
在本实施例中,第一标量信息基于第一模型参数变化量的转置矩阵与第一模型参数变化量/>的积/>获得,第二标量信息基于第一梯度信息变化量的转置矩阵与第一梯度信息变化量的积/>获得。
步骤S503,任一参与方设备与其他参与方设备进行交互以获得其他参与方设备的第三标量信息和第四标量信息;所述第三标量信息基于其他参与方设备的模型参数变化量的转置矩阵、模型参数变化量获得,所述第四标量信息基于其他参与方设备的梯度信息变化量的转置矩阵、梯度信息变化量获得。
在本实施例中,第三标量信息基于第二模型参数变化量的转置矩阵与第二模型参数变化量/>的积/>获得,第四标量信息基于第二梯度信息变化量的转置矩阵与第二梯度信息变化量的积/>获得。
在本实施例中,目标参与方设备Guest与其他参与方设备Host交换第一标量信息、第二标量信息、第三标量信息以及第四标量信息,从而使得目标参与方设备Guest与其他参与方设备Host均具有上述信息。
步骤S505,任一参与方设备基于所述第一标量信息、所述第二标量信息/>、所述第三标量信息/>、所述第四标量信息/>、以及中间变化量/>、/>计算本方第二中间值信息,并与其他参与方设备交换第二中间值信息并基于各参与方设备的第二中间值信息计算第二全局中间值,以根据所述第二全局中间值计算所述梯度搜索方向。
在本实施例中,前向循环过程中的第二中间值信息包括,目标参与方设备Guest与其他参与方设备Host分别计算本方的第二中间值信息/>之后,需交换各参与方设备的第二中间值信息,从而获得第二全局中间值。可选的,第二全局中间值可以为各参与方设备的第二中间值信息之和,或者也可以根据需求进行设置,本说明书对此不作限制。
可选的,步骤S505进一步包括:
步骤S601,根据目标参与方设备Guest和其他参与方设备Host交换的第一标量信息、第二标量信息/>、第三标量信息/>、第四标量信息/>的值计算。
步骤S603,目标参与方设备Guest和其他参与方设备Host分别计算,其中对角矩阵。/>
步骤S605,其他参与方设备Host方计算,目标参与方设备Guest计算。
步骤S607,迭代L轮,从0到/>,/>从/>到/>。其中,L表示预设的循环轮数,且L=3~5中的一个;/>表示当前的循环轮数。
1). 其他参与方设备 Host方计算;
2). 目标参与方设备Guest方计算;
3). 目标参与方设备Guest和其他参与方设备Host交换值后计算/>;
4). 其他参与方设备Host方计算;
5). 目标参与方设备Guest方计算
步骤S609,其他参与方设备Host方得到第二梯度搜索方向,目标参与方设备Guest方得到第一梯度搜索方向/>。
在上述实施例中,由于计算过程中除了一次单位矩阵与向量的乘法,其他都是向量乘法或标量乘法,不涉及大矩阵的计算,从而减小了模型训练过程中的计算量;同时,双方的交互变量都是向量内积之后的标量结果,保证了数据的安全性,减小了数据传输过程中的通信量,不仅可以保证模型的快速收敛,还可以提高目标参与方设备与其他参与方设备的硬件处理速率。可选的,在一些具体的实施例中,对于同一份样本数据,在一次更新周期内,本公开实施例所述联邦学习模型训练方法仅需3个循环轮数的迭代,即可使得模型收敛;而采用梯度下降方法则需要数十轮迭代才可保证模型收敛,因此本公开实施例所述联邦学习模型训练方法能够提高模型的收敛速度。
步骤S105,目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息。
在一些实施例中,步骤S105中所述目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息,包括:
步骤S701,目标参与方设备获取样本标签信息,并基于本方模型参数、特征信息以及其他参与方设备的第一数据信息获得样本标签预测信息;其中,所述第一数据信息基于其他参与方设备的模型参数、特征信息获得。
在本实施例中,目标参与方设备Guest首先基于本方模型参数、特征信息计算模型参数的转置矩阵与特征信息的乘积获得第二数据信息;之后,目标参与方设备Guest基于所述第二数据信息/>与其他参与方设备Host进行交互,获得其他参与方设备Host的第一数据信息/>;最后,目标参与方设备Guest基于第一数据信息/>、第二数据信息/>以及预设模型函数获得所述样本标签预测信息。
可选的,预设模型函数为函数,,/>函数定义为/>。
步骤S703,目标参与方设备基于所述样本标签预测信息及所述样本标签信息计算损失函数。
在本实施例中,损失函数:
。
步骤S705,目标参与方设备判断所述损失函数是否满足预设条件,若是,则将当前步长信息作为最终的步长信息;否则,减少所述步长信息的值并重新计算所述损失函数。
在本实施例中,预设条件可以为Armijo条件。因此,可判断损失函数是否满足Armijo条件,包括:
,其中/>为超参数 (例如可以取值/>。
若损失函数满足Armijo条件,则将当前步长信息作为最终的步长信息;若损失函数不满足Armijo条件,则将减少所述步长信息的值例如为原来的1/2,并基于减少后的步长信息以及第一梯度搜索方向、第二梯度搜索方向更新双方的模型参数后重新计算损失函数,直至损失函数不满足Armijo条件。
之后,可基于获得的步长信息以及第一梯度搜索方向更新第一模型参数,其中,。
当双方的梯度变化稳定即阈值时,停止训练,模型更新完成。
需要说明的是,本公开实施例的方法可以由单个设备执行,例如一台计算机或服务器等。本实施例的方法也可以应用于分布式场景下,由多台设备相互配合来完成。在这种分布式场景的情况下,这多台设备中的一台设备可以只执行本公开实施例的方法中的某一个或多个步骤,这多台设备相互之间会进行交互以完成所述的方法。
需要说明的是,上述对本公开的一些实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于上述实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上任意一实施例所述的方法。
图6示出了本实施例所提供的一种更为具体的电子设备硬件结构示意图, 该设备可以包括:处理器1010、存储器1020、输入/输出接口1030、通信接口1040和总线 1050。其中处理器1010、存储器1020、输入/输出接口1030和通信接口1040通过总线1050实现彼此之间在设备内部的通信连接。
处理器1010可以采用通用的CPU(Central Processing Unit,中央处理器)、微处理器、应用专用集成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本说明书实施例所提供的技术方案。
存储器1020可以采用ROM(Read Only Memory,只读存储器)、RAM(Random AccessMemory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。
输入/输出接口1030用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在设备中(图中未示出),也可以外接于设备以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。
通信接口1040用于连接通信模块(图中未示出),以实现本设备与其他设备的通信交互。其中通信模块可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信。
总线1050包括一通路,在设备的各个组件(例如处理器1010、存储器1020、输入/输出接口1030和通信接口1040)之间传输信息。
需要说明的是,尽管上述设备仅示出了处理器1010、存储器1020、输入/输出接口1030、通信接口1040以及总线1050,但是在具体实施过程中,该设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述设备中也可以仅包含实现本说明书实施例方案所必需的组件,而不必包含图中所示的全部组件。
上述实施例的电子设备用于实现前述任一实施例中相应的方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行如上任一实施例所述的方法。
本实施例的计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。
上述实施例的存储介质存储的计算机指令用于使所述计算机执行如上任一实施例所述的方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
所属领域的普通技术人员应当理解:以上任何实施例的讨论仅为示例性的,并非旨在暗示本公开的范围(包括权利要求)被限于这些例子;在本公开的思路下,以上实施例或者不同实施例中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上所述的本公开实施例的不同方面的许多其它变化,为了简明它们没有在细节中提供。
另外,为简化说明和讨论,并且为了不会使本公开实施例难以理解,在所提供的附图中可以示出或可以不示出与集成电路(IC)芯片和其它部件的公知的电源/接地连接。此外,可以以框图的形式示出装置,以便避免使本公开实施例难以理解,并且这也考虑了以下事实,即关于这些框图装置的实施方式的细节是高度取决于将要实施本公开实施例的平台的(即,这些细节应当完全处于本领域技术人员的理解范围内)。在阐述了具体细节(例如,电路)以描述本公开的示例性实施例的情况下,对本领域技术人员来说显而易见的是,可以在没有这些具体细节的情况下或者这些具体细节有变化的情况下实施本公开实施例。因此,这些描述应被认为是说明性的而不是限制性的。
尽管已经结合了本公开的具体实施例对本公开进行了描述,但是根据前面的描述,这些实施例的很多替换、修改和变型对本领域普通技术人员来说将是显而易见的。例如,其它存储器架构(例如,动态RAM(DRAM))可以使用所讨论的实施例。
本公开实施例旨在涵盖落入所附权利要求的宽泛范围之内的所有这样的替换、修改和变型。因此,凡在本公开实施例的精神和原则之内,所做的任何省略、修改、等同替换、改进等,均应包含在本公开的保护范围之内。