CN111460528A - 一种基于Adam优化算法的多方联合训练方法及系统 - Google Patents
一种基于Adam优化算法的多方联合训练方法及系统 Download PDFInfo
- Publication number
- CN111460528A CN111460528A CN202010248683.8A CN202010248683A CN111460528A CN 111460528 A CN111460528 A CN 111460528A CN 202010248683 A CN202010248683 A CN 202010248683A CN 111460528 A CN111460528 A CN 111460528A
- Authority
- CN
- China
- Prior art keywords
- data
- training
- accumulated
- model
- moment
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/70—Protecting specific internal or peripheral components, in which the protection of a component leads to protection of the entire computer
- G06F21/71—Protecting specific internal or peripheral components, in which the protection of a component leads to protection of the entire computer to assure secure computing or processing of information
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Computer Hardware Design (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computer Security & Cryptography (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本说明书一个或多个实施例涉及一种基于Adam优化算法的多方联合训练方法及系统。所述方法包括:基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员;服务器通过多方安全计算的方式获取累计数据;所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数确定;服务器基于所述累计数据以及样本标签参与累计梯度值的计算;所述累计梯度值用于所述训练成员计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新;所述各个数据终端分别持有自身的训练数据以及与所述训练数据对应的模型参数;所述训练数据包括与实体相关的图像数据、文本数据或声音数据。其中,所述训练数据可以包括私有数据。
Description
技术领域
本说明书一个或多个实施例涉及多方数据合作,特别涉及一种Adam算法进行多方联合训练的方法和系统。
背景技术
随着人工智能技术的发展,梯度下降优化算法已逐渐应用于医疗、金融等领域。为得到更好的模型性能,梯度下降优化算法就需要更多的训练数据进行模型优化。在不同的企业或机构拥有不同的数据样本,将这些数据进行联合训练,能够提升模型精度,给企业带来很大的经济效益。因此,梯度下降优化算法常被用在多方参与的深度学习训练。Adam算法作为梯度下降优化算法中的一种,由于其收敛速度快、超参可进行自适应调整等优点被广泛使用。
因此,有必要提供一种基于Adam优化算法的联合训练方法,以提高有多方数据拥有者参与的联合训练的效率。
发明内容
本说明书实施例的一个方面提供一种基于Adam优化算法的多方联合训练方法;所述方法包括:基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员;服务器通过多方安全计算的方式获取累计数据;所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数确定;服务器基于所述累计数据以及样本标签参与累计梯度值的计算;所述累计梯度值用于所述训练成员计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新;所述一阶矩和所述二阶矩分别用于反映所述累计梯度值的期望和方差;其中,所述各个数据终端分别持有自身的训练数据以及与所述训练数据对应的模型参数;所述训练数据包括与实体相关的图像数据、文本数据或声音数据。
本说明书实施例的另一个方面提供一种基于Adam优化算法的多方联合训练系统;所述系统包括:训练成员确定模块,用于基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员;累计数据获取模块,用于通过多方安全计算的方式获取累计数据;所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数确定;累计梯度值计算模块,用于基于所述累计数据以及样本标签参与累计梯度值的计算;所述累计梯度值用于所述训练成员计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新;所述一阶矩和所述二阶矩分别用于反映所述累计梯度值的期望和方差;其中,所述各个数据终端分别持有自身的训练数据以及与所述训练数据对应的模型参数;所述训练数据包括与实体相关的图像数据、文本数据或声音数据。
本说明书实施例的另一个方面提供一种基于Adam优化算法的多方联合训练装置,所述装置包括处理器以及存储器;所述存储器用于存储指令,所述处理器用于执行所述指令,以实现所述基于Adam优化算法的多方联合训练方法对应的操作。
附图说明
本说明书将以示例性实施例的方式进一步描述,这些示例性实施例将通过附图进行详细描述。这些实施例并非限制性的,在这些实施例中,相同的编号表示相同的结构,其中:
图1是根据本说明书一些实施例所示的基于Adam算法进行多方联合训练系统的示例性应用场景图;
图2是根据本说明书一些实施例所示的一种Adam算法进行多方联合训练的方法的示例性流程图;以及
图3是根据本说明书的另外一些实施例所示的Adam算法进行多方联合训练方法的示例性示意图。
具体实施方式
为了更清楚地说明本申请实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单的介绍。显而易见地,下面描述中的附图仅仅是本申请的一些示例或实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图将本申请应用于其它类似情景。除非从语言环境中显而易见或另做说明,图中相同标号代表相同结构或操作。
应当理解,本说明书中所使用的“系统”、“装置”、“单元”和/或“模组”是用于区分不同级别的不同组件、元件、部件、部分或装配的一种方法。然而,如果其他词语可实现相同的目的,则可通过其他表达来替换所述词语。
如本说明书和权利要求书中所示,除非上下文明确提示例外情形,“一”、“一个”、“一种”和/或“该”等词并非特指单数,也可包括复数。一般说来,术语“包括”与“包含”仅提示包括已明确标识的步骤和元素,而这些步骤和元素不构成一个排它性的罗列,方法或者设备也可能包含其它的步骤或元素。
本说明书中使用了流程图用来说明根据本说明书的实施例的系统所执行的操作。应当理解的是,前面或后面操作不一定按照顺序来精确地执行。相反,可以按照倒序或同时处理各个步骤。同时,也可以将其他操作添加到这些过程中,或从这些过程移除某一步或数步操作。
在经济、文化、教育、医疗、公共管理等各行各业充斥的大量信息数据,对其进行例如数据分析、数据挖掘以及趋势预测等数据处理分析在越来越多场景中得到广泛应用。其中,通过数据合作的方式可以使多个数据拥有方获得更好的数据处理结果。例如,可以通过多方数据的联合训练来获得更为准确的模型参数。
在一些实施例中,基于梯度下降的优化算法进行多方联合训练的方法可以应用于在保证各方数据安全的情况下,各方协同训练机器学习模型供多方使用的场景。在这个场景中,多个数据方拥有自己的数据,他们想共同使用彼此的数据来统一建模(例如,线性回归模型、逻辑回归模型或深度神经网络模型等),但并不想各自的数据被泄露。例如,互联网储蓄机构A拥有一批用户数据,政府银行B拥有另一批用户数据,基于A和B的用户数据确定的训练样本集可以训练得到比较好的机器学习模型。A和B都愿意通过彼此的用户数据共同参与模型训练,但因为一些原因A和B不愿意自己的用户数据信息遭到泄露,或者至少不愿意让对方知道自己的用户数据信息。
在一些实施例中,基于梯度下降的优化算法进行多方联合训练的系统可以使多方的数据在不受到泄露的情况下,通过多方数据的联合训练来得到共同使用的机器学习模型,达到一种共赢的合作状态。
在一些实施例中,可以使用梯度下降法(又称SGD算法)进行模型学习和训练。梯度下降法采用单个样本依次进行训练,训练样本的差异性带来很多噪声,使得损失函数并不是每次迭代都向着整体最优化方向,收敛速度逐渐降低,训练时间长。在一些实施例中,也可以采用小批量梯度下降法(又称MBGD算法)进行模型学习和训练。小批量梯度下降法减少了噪声引起的损失函数波动,但很难选择合适的学习率,学习率难以随着训练迭代次数进行自适应调整。另外对于部分损失函数,小批量梯度下降法容易得到局部最小值,而无法得到全局最小值,以至于模型性能无法达到最优。在一些实施例中,还可以采用Momentum算法进行模型学习和训练。Momentum算法进行在梯度中引入动量,相比于SGD算法和MBGD算法,Momentum算法的稳定性较强。但当动量与当前梯度方向相反时会降低收敛速度,此外初始训练梯度较大时,原始累积梯度权重较小,训练后期梯度较小时,原始累积梯度权重较大,需要人为调节,无法实现参数的自适应性。
在一些实施例中,还可以采用Adam优化算法进行模型学习和训练。Adam优化算法利用一阶矩和二阶矩提高了训练收敛速度,超参可进行自适应调整。但是,Adam优化算法对于计算一阶矩和二阶矩过程依赖于迭代次数进行更新,一旦在多方联合训练中发生通讯中断,一阶矩和二阶矩无法完成累加,导致模型更新有误,影响后续训练。
在一些实施例中,可以对基于Adam优化算法进行的模型训练进行进一步的优化,在一些实施例中,可以选取通信状态良好的数据持有者的样本数据参与一阶矩和二阶矩的计算,让模型训练继续进行,保证部分训练成员通信中断不影响整体训练进度。例如,可以基于通信良好的成员持有的训练数据以及模型参数来计算累计梯度Δh,参与训练的成员可以基于所述累计梯度来计算一阶矩二阶矩进而通过所述一阶矩和二阶矩计算更新模型参数WAi,作为所述多方联合训练模型的参数。
在一些实施例中,由于通信中断而未参与一阶矩和二阶矩计算的部分成员,在恢复通信连接状态后可以继续计算更新一阶矩和二阶矩,保证模型训练的准确性。例如,恢复通信连接的成员可以由参与训练的成员计算得到的所述累计梯度Δh计算自身的一阶矩和二阶矩并更新下一次迭代的一阶矩和二阶矩的初始值,从而确保下一次迭代训练相关参数的准确性。
图1为根据本说明书的一些实施例所示的基于Adam算法进行多方联合训练系统的示例性应用场景图。
需要注意的是,图1仅作为示例性的说明,图1中数据拥有者的数量可以为两方,在其他实施例中,还可以包括第三方数据拥有者、第四方数据拥有者以至第N方数据拥有者等。
在一些实施例中,多方联合训练系统100包括第一终端110、第二终端120、服务器130以及网络140。其中,第一终端110可以理解为第一方数据拥有者,包括处理设备110-1,存储设备110-2;第二终端120可以理解为第二方数据拥有者,包括处理设备120-1,存储设备120-2;服务器130包括处理设备130-1,存储设备130-2。在一些实施例中,服务器130可以属于参与联合训练的多个数据拥有者中的一个,也包含自身的训练样本数据。在另一些实施例中,服务器130也可以属于独立于所述各个数据拥有者的可信第三方,不包含训练样本数据,仅进行计算和数据存储。
多方数据拥有者可以通过多方安全求交得到训练集合和样本标签,并同时进行编号,分别得到各方数据拥有者的训练样本。所述多方安全求交,是各方在不暴露自身数据的情况下运算得到多方数据的交集。由于多方数据拥有者采用多方安全求交技术,多方数据拥有者仅知道各自拥有的训练样本但不知道其他任意一方的训练样本。
第一终端110、第二终端120均可以是带有数据获取、存储和/或发送功能的设备。在一些实施例中,第一终端110、第二终端120可以包括但不限于移动设备、平板电脑、笔记本电脑、台式电脑等或其任意组合。在一些实施例中,第一终端110和第二终端120可以接收来自服务器130的相关数据。例如,第一终端110可以接收来自服务器的迭代次数。例如,第一终端110和第二终端120可以从服务器130处接收与迭代次数差值对应的梯度值。
第一终端和第二终端的处理设备110-1和120-1可以进行数据和/或指令处理。处理设备110-1和120-1可以对数据进行计算,也可以执行相关算法和/或指令。例如,第一终端110的处理设备110-1可以接收来自服务器130的梯度集合,并用所存储的一阶矩原始值计算一阶矩,也可以利用一阶矩和二阶矩计算得到第一终端的更新后的模型参数。
第一终端和第二终端的存储设备110-2和120-2可以存储对应处理设备110-1和120-1执行使用的数据和/或指令,处理设备110-1和120-1可以通过执行或使用所述数据和/或指令以实现本说明书中的示例性方法。存储设备110-2和120-2可以分别用于存储一阶矩和二阶矩的初始值;也可以存储指示第一终端和第二终端执行操作的相关指令。存储设备110-2和120-2还可以分别存储经处理设备110-1和120-1处理后数据。例如,存储设备110-2和120-2还可以分别存储各终端的迭代次数和相关梯度值。在一些实施例中,存储设备110-2和存储设备120-2也可以是一个存储设备,其中,第一终端和第二终端只能从该存储设备中获取自己存储的数据。在一些实施例中,存储设备可包括大容量存储器、可移动存储器、易失性读写存储器、只读存储器(ROM)等或其任意组合。
服务器130可以是带有数据获取、存储和/或发送功能的设备,例如,云端服务器,终端处理设备等。在一些实施例中,服务器130可以接收来自第一终端110和第二终端120的相关数据。例如,服务器130可以接收来自第一终端110的训练数据的样本个数和/或特征维度。
服务器的存储设备130-2可以存储处理设备130-1执行使用的数据和/或指令,处理设备130-1可以通过执行或使用所述数据和/或指令以实现本说明书中的示例性方法。例如,存储设备130-2和120-2可以用于存储迭代次数;也可以存储指示第一终端和第二终端执行操作的相关指令。在一些实施例中,存储设备可包括大容量存储器、可移动存储器、易失性读写存储器、只读存储器(ROM)等或其任意组合。
网络140可以促进信息和/或数据的交换。在一些实施例中,多方联合训练系统100(例如,第一终端110(处理设备110-1和存储设备110-2)、第二终端120(处理设备120-1和存储设备120-2)和服务器(处理设备130-1和存储设备130-2))的一个或以上部件可以经由网络140向所述系统100中的其他有数据传输需求的部件发送信息和/或数据。例如,第二终端120的处理设备120-1可以经由网络140从服务器130中获得迭代次数。又例如,第一终端110的处理设备110-1可以通过网络140从服务器130的存储设备110-2中获取梯度集合。在一些实施例中,网络140可以为任意形式的有线或无线网络,或其任意组合。本说明书一个或多个实施例中的系统,可以由数据传输模块及若干个数据传输模块组成。
在一些实施例中,在以服务器为执行主体的系统中,所述数据传输模块可以包括第一数据传输模块、第二数据传输模块、第三数据传输模块。所述数据处理模块可以包括训练成员确定模块、累计数据获取模块、累计梯度值计算模块。上述模块均在应用场景所介绍的计算系统中执行,各模块包括各自的指令,指令可存储在存储介质上,指令可在处理器中执行。不同的模块可以位于相同的设备上,也可以位于不同的设备上。它们之间可以通过程序接口、网络等进行数据的传输,可以从存储设备中读取数据或者将数据写入到存储设备中。
训练成员确定模块,用于基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员;在一些实施例中,所述通信状态包括数据传输的丢包率和/或等待时间;当所述丢包率在预设的丢包率阈值内,和/或所述等待时间在预设的时间阈值内时,所述训练成员确定模块还用于将所述完成数据传输的数据持有终端确定为参与模型参数更新的训练成员。
累计数据获取模块,用于通过多方安全计算的方式获取累计数据;所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数确定。
累计梯度值计算模块,用于基于所述累计数据以及样本标签参与累计梯度值的计算;所述累计梯度值用于所述训练成员计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新。
在一些实施例中,当所述模型包括深度神经网络模型,且所述神经网络模型中部分层的模型参数在服务器端更新时;所述累计梯度值计算模块还用于:基于所述累计数据以及样本标签计算累计损失值;基于所述累计损失值,确定所述累计梯度值;并将所述累计梯度值传输给所述训练成员。
在一些实施例中,第一数据传输模块,用于将模型参数更新的预设迭代次数以及预设批处理参数传输给各个数据持有终端;所述预设迭代次数以及所述预设批处理参数用于对各个数据持有终端的训练数据进行编号。
在一些实施例中,第二数据传输模块,用于基于所述训练成员当前的迭代次数与所述预设迭代次数之间的迭代次数差值,将与所述迭代次数差值对应个数的累计梯度值传输给所述训练成员。
在一些实施例中,第三数据传输模块,当通信状态恢复通信连接时,用于将基于未参与模型参数更新的其他成员的当前迭代次数与所述预设迭代次数之间的迭代次数差值,将与所述迭代次数差值对应个数的累计梯度值传输给所述其他成员;所述累计梯度值用于所述其他成员计算一阶矩和二阶矩。
应当理解,本说明书一个或多个实施中的所述系统及其模块可以利用各种方式来实现。例如,在一些实施例中,系统及其模块可以通过硬件、软件或者软件和硬件的结合来实现。其中,硬件部分可以利用专用逻辑来实现;软件部分则可以存储在存储器中,由适当的指令执行系统,例如微处理器或者专用设计硬件来执行。本领域技术人员可以理解上述的方法和系统可以使用计算机可执行指令和/或包含在处理器控制代码中来实现,例如在诸如磁盘、CD或DVD-ROM的载体介质、诸如只读存储器(固件)的可编程的存储器或者诸如光学或电子信号载体的数据载体上提供了这样的代码。本申请的系统及其模块不仅可以有诸如超大规模集成电路或门阵列、诸如逻辑芯片、晶体管等的半导体、或者诸如现场可编程门阵列、可编程逻辑设备等的可编程硬件设备的硬件电路实现,也可以用例如由各种类型的处理器所执行的软件实现,还可以由上述硬件电路和软件的结合(例如,固件)来实现。
需要注意的是,以上对于处理设备及其模块的描述,仅为描述方便,并不能把本申请限制在所举实施例范围之内。可以理解,对于本领域的技术人员来说,在了解该系统的原理后,可能在不背离这一原理的情况下,对各个模块进行任意组合,或者构成子系统与其他模块连接。
图2是根据本说明书的一些实施例所示的基于Adam优化算法进行联合训练方法的示例性流程图。
本说明书中的变量名称、公式仅为更好地理解本说明书所述的方法。在应用本说明书时,基于常见的运算原理和机器学习原理,可以对下述过程、变量名称、公式做各种非实质性的变换,例如调换行或列的次序、在矩阵乘法时变换为等价形式、以其他计算形式来表示同一计算等。
需要说明的是,在本说明书中所涉及的一个或多个实施例的方法不仅可以应用于数据垂直切分的多方进行联合训练,也可以应用于数据水平切分的多方进行联合训练。所述数据垂直切分是指各个数据方持有的训练样本的数据ID相同、特征维度不同;所述数据水平切分是指各个数据方持有的训练样本的数据ID不同、特征维度相同。在以下一个或多个实施例中,以数据垂直切分为例进行说明。
在本说明书中,约定按以下方式表示:
在联合训练中,A1、A2、…、AN为联合训练中的数据所有者,N为数据所有者总数,服务器D作为模型训练的驱动中心。在一些实施例中,所述模型可以是一个线性模型。在一些实施例中,所述模型也可以是一个深度神经网络模型。为了提高通信效率,在一些实施例中,所述深度神经网络模型前面几层的模型参数计算在各数据方A1、A2、…、AN执行,后面几层模型参数在服务器D中执行。图2给出的实施例为深度神经网络学习模型,且部分网络层的模型参数在服务器D上计算更新,部分网络层的模型参数在各数据方计算更新。
在本说明书的表示中,在进行说明之前对联合训练中的数据所有者A1、A2、…、AN以及服务器D进行说明。
服务器D在联合训练开始前预先设定模型迭代次数T和各个数据所有者A1、A2、…、AN每次训练的样本数batch_size。在一些实施例中,迭代次数T可以是一个大于等于0的自然数,用以控制整个模型的计算和更新的次数。在一些实施例中,最大迭代次数可以设定为100、200或300等。在初始化时,可以将迭代次数T初始化为0,然后以此累加直到达到预设的最大迭代次数。每次训练样本数batch_size可以用来确认每次计算的样本数量。在一些实施例中,batch_size可以设定为大于0的自然数,如batch_size可以为100,200,300等。
下面对数据所有者A1、A2、…、AN进行说明。为方便说明,本说明书的一些实施例,以Ai方数据所有者为例进行详细说明,Ai可以是联合训练中的数据所有者A1、A2、…、AN中的任意一个。在一些实施例中,第Ai方数据所有者也可以称之为第i终端或第i方。数据所有者Ai的内部持有Ai方的当前迭代次数tAi,权重矩阵WAi和偏置向量bAi。Ai方的迭代次数tAi是一个大于等于0的自然数。
在一些实施例中,数据所有者Ai可以利用训练迭代次数T和每次训练样本数batch_size对自身参与训练的样本数据进行编号,所述编号包括每次参与迭代训练的样本数据的起始编号batch_begin_i和终止编号batch_end_i。数据所有者Ai的内部还含有当前训练的起始编号batch_begin_now_i和终止编号batch_end_now_i;以及需要进行训练的起始编号batch_begin_i和终止编号batch_end_i。在一些实施例中,需要进行训练的样本数据的起始编号batch_begin_i和终止编号batch_end_i可以由迭代次数T简单计算得到。
在本说明书的约定中,
batch_begin_i=batch_size*(T-1),
batch_end_i=batch_size*T。
在一些实施例中,以上基于迭代次数进行编号可以是由数据所有者Ai完成。在一些实施例中,基于迭代次数进行编号的这一过程也可以是服务器D完成。服务器D计算完成后将需要进行训练的起始编号batch_begin_i和终止编号batch_end_i直接发送给对应的数据所有者Ai。
步骤210,数据所有者进行数据对齐和拆分。
多方数据所有者Ai(i=1,...,N)之间利用多方安全求交得到具有相同ID的数据交集X,对交集进行编号,分别得到Ai(i=1,...,N)的数据集XAi。多方数据所有者Ai中的任意一方可以持有训练样本的标签数据。在本说明书中,假定数据所有者A1包含训练的标签数据。多方数据所有者Ai将多方安全求交得到的数据交集XAi进行拆分得到训练集XtrAi,验证集XvalAi,测试集XieAi;并得到训练集XtrAi的数据维度,样本个数nAi以及特征个数fAi。在一些实施例中,也可以仅得到训练集XtrAi,省略验证集XvajAi和测试集XteAi。
步骤220:服务器及数据所有者初始化。
该步骤包括服务器以及数据所有者将模型相关数据进行初始化设置,以及彼此之间进行通信连接。
在一些实施例中,服务器D确定深度神经网络的结构,即网络的层数和每层的节点数,其中分为数据所有者内部的网络结构和服务器D的网络结构。确定服务器D以及数据所有者A1、A2、…、AN可以通信,例如,服务器D以及数据所有者A1、A2、…、AN确定各自的通信IP地址。
服务器D设定每次训练的样本数batch_size,例如,设定batch_size=100,并将模型迭代次数T置为0。
在一些实施例中,数据所有者Ai的迭代次数初始化tAi,即tAi=0;初始化当前参与训练的样本数据的起始编号batch_begin_now_i和batch_end_now_i,即batch_begin_now_i=0,batch_end_now_i=0;初始化权重矩阵WAi和偏置向量bAi;初始化一阶矩和二阶矩即初始化梯度集合ΔH为空集。
在一些实施例中,数据所有者之间进行通信连接时,数据拥有者Ai将本次参与迭代训练的样本个数nAi进行广播,并确认与其他数据拥有者样本个数是否相同。若存在任意一组数据拥有者Ai、Aj,使得nAi≠nAj,则返回步骤210重新进行多方安全求交运算。
在一些实施例中,服务器与所有数据所有者进行通信连接时,服务器D将数据所有者的网络结构发送给所有的数据拥有者Ai,所有的数据拥有者Ai将各自的样本个数nAi发送给服务器。在一些实施例中,服务器D根据IP地址与所有的数据所有者Ai进行通信连接。
步骤230:服务器驱动多方训练。
当前述准备工作完成后,服务器可以驱动多方开始模型训练,此处多方即为多个数据方。由于本实施例部分的深度神经网络模型在各数据方和服务器均设有模型参数的训练,所以以下步骤会分别涉及服务器端和各数据端的模型参数更新。具体的训练过程详见步骤231~步骤239。
步骤231:数据所有者更新当前训练数据的起始编号和终止编号。
服务器D将迭代次数T增加1,即令T=T+1,并将更新后的迭代次数T下发至所有数据所有者Ai。数据所有者Ai基于接收到的迭代次数T计算迭代次数差值ΔtAi=T-tAi;并更新迭代次数tAi=T。计算需要参与本次迭代训练的样本数据的起始编号batch_begin_i=batch_size*(T-1)和终止编号batch_end_i=batch_size*T。
各数据所有者Ai根据步骤230中得到的终止编号batch_end_i与当前参与训练的终止编号的batch_end_now_i进行比较,判断是否需要更新当前参与训练的起始/终止编号。若batch_end_now_i小于batch_end_i,则将batch_end_i的值赋给batch_end_now_i,将batch_begin_i的值赋给batch_begin_now_i。
在本说明书的约定中,若batch_end_now_i<batch_end_i,则令batch_begin_now_i=batch_begin_i,batch_end_now_i=batch_end_i;以该公式中,符号“=”指的是将“=”后项的值赋给前项。
数据所有者Ai根据需要进行训练的样本数据的起始编号batch_begin_now_i和终止编号batch_end_now_i读取对应的训练数据XAi以及A1中的标签数据Y。
由于数据所有者Ai可能出现断线等情况,将batch_end_i的值赋给batch_end_now_i能够确保数据所有者Ai读取的样本数据是所需要的数据。
在一些实施例中,各数据所有者Ai还可以根据步骤230中得到的起始编号batch_begin_i与当前参与训练的起始编号的batch_begin_now_i进行比较,判断是否需要更新当前参与训练的样本数据的起始/终止编号。
在一些实施例中,当数据所有者中有部分数据所有者产生通信中断时,所述起始编号和终止编号可以反应发生通信中断的数据所有者中已经参加迭代训练的样本个数,在通信状态恢复连接后,可以基于上述编号的比较关系让迭代训练从没有参与训练的样本数据继续进行。
步骤233:基于通信状态确定参与训练的成员,并确定累积梯度值。
数据所有者之间A1、A2、…、AN通过秘密分享的方式发送XAi和WAi,并通过通信状态确定参与本次迭代训练的训练成员。
在一些实施例中,服务器D可以根据通信状态,确定参与本次迭代训练的训练成员。在一些实施例中,若数据所有者A1、A2、…、AN的通信状态处于连接状态时,则确定其为参与训练的训练成员;若数据所有者A1、A2、…、AN的通信状态处于通信断开时,则确定其不参与训练。在一些实施例中,通信状态可以基于丢包率和等待时间门限确定。所述丢包率指测试中所丢失数据包数量占所发送数据包的比率,所述等待时间门限是指数据传输或反馈的最长时间。在一些实施例中,可以基于丢包率和等待时间分别设定对应的预设值,以判定数据所有者A1、A2、…、AN的通信状态。例如,当某一数据所有者在数据传输过程中的丢包率和等待时间都分别小于对应的预设值时,则可判定该数据所有者通信连接。反之,则判定为通信中断。再例如,当某一数据所有者的数据传输过程中的丢包率和等待时间的其中一个小于对应的预设值时,则可判定该数据所有者通信连接。
在一些实施例中,假设当前本次迭代训练的训练成员的总数为k(k≤N),成员编号集合K={1,...,j,...}。各个训练成员利用秘密共享算法得到计算结果h0=<XA1,…,XAj,…>×<WA1,…,WAj,….>,并将所述计算结果h0发送给服务器D。
在一些实施例中,在秘密共享算法运算过程中,各个训练成员可以将当前的迭代次数tAj作为样本验证信息,确保每次参与训练的样本数据都具有相同的样本编号。对应地,在一些实施例中,各个训练成员Aj可以将所有计算得到的ΔtAj发送服务器D;持有样本标签的训练成员将样本标签Y发送给服务器D。
步骤235:服务器更新模型参数。
在一些实施例中,服务器D接收到当前的迭代次数tAj、隐层计算结果h0和标签Y后,验证接收到的tAj和步骤230中发送的T是否相同,如果相同则继续执行,如果不同则继续等待Aj发送正确的h0。接收到的tAj和步骤230中发送的T是否相同是为了保证服务器接收到了正确的h0。如果不验证tAj和步骤230中发送的T是否相同,数据所有者Aj可能由于断线、计算过慢等原因发送不是本轮迭代的计算结果h0,从而导致服务器计算梯度时的结果发生偏差。
在一些实施例中,服务器D可以利用激活函数对h0进行激活。在一些实施例中,所用激活函数包括但不限于Sigmoid、Tanh、ReLU、Leaky ReLU、Maxout、Softmax等。引入激活函数可以给神经元引入了非线性因素。如果不用激活函数,无论神经网络有多少层,输出都是输入的线性组合。在本说明书中,以激活函数为Sigmoid为例进行说明。
服务器D基于h0得到服务器输入层的梯度值Δh,并将梯度值Δh存储到梯度集合ΔH中。
在一些实施例中,服务器D的网络结构可以是多层的。例如,服务器D可以具有多个隐藏层。又例如,服务器D可以具有输入层和输出层。当服务器D的实际输出与期望输出不符时,进入的反向传播阶段。所谓反向传播,指的是误差通过输出层,按梯度下降的方式依次修正各层权值,向隐层、输入层逐层反传。通过所述反向传播,服务器D可以更新在输入层的梯度值Δh。
在本说明书的约定中,
在一些实施例中,当服务器D的网络结构具有多层时,服务器基于输入层的梯度值Δh计算下面每一层的梯度值,从而完成服务器侧的模型参数的更新。
步骤237:参与训练的训练成员更新模型参数。
在一些实施例中,服务器D从梯度集合ΔH中选取对应的梯度值发送给对应的训练成员;训练成员分别基于接收到的梯度值以及自身持有的其他数据计算一阶矩和二阶矩,并基于一阶矩和二阶矩计算更新各自的模型参数,即权重矩阵WAi。
在一些实施例中,服务器D根据ΔtAj从梯度集合ΔH中依次选取最后ΔtAj个梯度矩阵组成为参与此次迭代训练的数据所有者Aj的对应的梯度集合ΔHj,并发送至对应的数据所有者Aj。
所有参与此次迭代训练的数据所有者Aj根据各自对应的梯度集合ΔHj,利用Adam优化算法,按照如下公式计算更新各自的模型参数。
在本说明书的约定中:
步骤239:未参与训练的其他成员计算一阶矩和二阶矩。
在一些实施例中,对于未参与训练的成员,如果其通信状态恢复,服务器D也可以基于迭代次数差值把梯度值下发给对应的成员,各成员计算一阶矩,二阶矩,并完成模型超参数的更新。
上述过程说明了基于Adam优化算法继续的一次模型参数的更新过程,可迭代步骤230至步骤235直至收敛或者完成设定的迭代训练次数。
在进行多次模型迭代的过程中,如果任意一数据所有者断线后重连,都可以在下一次循环进行至步骤231时加入训练。重新加入训练的任意一数据所有者可以获取未计算的多个梯度矩阵Δh(即梯度集合ΔH),并通过前述公式将多个梯度矩阵Δh代入得到最新的模型。例如,数据所有者Am在某次迭代训练的开始阶段由于网络不稳定处于通信断开状态,因而被确定为不参与模型参数更新的其他成员。当此次迭代训练过程结束时,即参与训练的成员基于所述梯度矩阵计算完一阶矩、二阶矩以及模型参数WAi时,检测到数据所有者Am的通信状态又恢复通信连接了,此时,恢复通信连接的数据所有者Am可以重新加入训练,即根据迭代次数差值获取对应个数的梯度矩阵,并计算更新一阶矩和二阶矩的值,为下一次迭代训练做准备。
在一些实施例中,可以用验证集对模型进行验证,以实现模型超参数的调整。可以利用测试集对训练好的模型进行测试,如果结果不满足需求则可通过增加训练集、增加训练迭代轮数、数据正则化等方法,提高模型的准确度。
需要说明的是,以上描述中涉及的一个或多个实施例的方法是以数据垂直切分的数据多方进行联合训练为例进行的示例性说明。上述涉及的一个或多个实施例的方法也可以应用于数据水平切分的多方进行联合训练。如果将上述实施方案中的垂直划分换成水平划分情况的话,对应的发生的变化包括但不限于如下几个地方:
(1)步骤210,多方数据所有者Ai(i=1,...,N)之间利用多方安全求交得到具有相同特征的数据交集X,对交集进行编号,分别得到Ai(i=1,...,N)的数据集XAi。多方数据所有者Ai分别持有自身训练数据的样本标签。
(2)步骤222,数据拥有者Ai将特征个数fAi进行广播,并确认与其他数据拥有者特征个数是否相同。若存在任意一组数据拥有者Ai、Aj,使得fAi≠fAj,则返回步骤210重新进行求交运算,否则继续进行步骤223。
(3)步骤223,服务器D将数据所有者的网络结构发送给所有的Ai,同时Ai将各自的特征个数fAi发送给服务器。
(4)步骤231,数据所有者Ai根据参与训练的训练数据的起始编号batch_begin_i和终止编号batch_end_i得到此次训练的训练数据XAi以及与训练数据对应的样本标签Y,因为每个数据所有者Ai都分别持有自身的样本标签。
(5)利用秘密共享算法按照如下公式得到计算结果:
h0=<XA1,…,XAj,…>T×<WA1,…,WAj,….>。
图3为根据本说明书的一些实施例所示的基于Adam优化算法的多方联合训练方法的示例性流程图。
在一些实施例中,方法300中的一个或以上步骤可以在图1所示的多方联合训练系统100中实现。例如,方法300中的一个或以上步骤可以作为指令的形式存储在存储设备120中,并被处理设备110调用和/或执行。
步骤310,服务器基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员。在一些实施例中,步骤310可以由训练成员确定模块执行。
在一些实施例中,所述服务器可以理解为具有数据和/或指令接收、处理能力的设备,该设备可以是终端处理设备也可以云端处理设备。在一些实施例中,所述服务器可以来自于独立于各个数据持有方的可信第三方;也可以来自于各数据持有方的其中一方。
在一些实施例中,各个数据持有终端可以是图2部分描述的各个数据方或各个数据所有者Ai。各个数据所有者Ai分别持有自身的训练数据以及与训练数据相对应的模型参数。在一些实施例中,各个数据所有者Ai持有的训练数据和模型参数不对外公开。服务器可以通过多方安全计算的方式(例如,秘密共享等方式)获取各个数据所有者Ai的样本数据。其中,各个数据持有终端持有的训练数据可以包括垂直划分和水平划分两种情况,所述数据垂直切分是指各个数据持有终端的训练数据的样本个数相同,特征维度不同;所述数据水平切分是指各个数据持有终端的训练数据的特征维度相同,但样本个数不同。当各个数据持有终端的训练数据分别为垂直划分和水平划分两种情况时,本说明书一个或多个实施例发生的变化可参见图2部分相关说明。
在一些实施例中,参与模型参数更新的训练成员可以从各个数据持有终端中确定。在一些实施例中,参与模型参数更新的训练成员可以通过各个数据持有终端的通信状态确定。若数据持有终端的通信状态处于通信连接时,则服务器确定其参与训练,即为参与模型参数更新的训练成员;若数据持有终端的通信状态处于通信断开时,则服务器确定其为不参与训练,即为不参与模型参数更新的数据持有终端。本说明书一个或多个实施例中,训练成员可以理解为参与模型参数更新或参与训练的数据持有终端。在每一次迭代训练中,服务器都会通过各个数据持有终端的通信状态确定参与此次迭代训练或此次模型参数更新的训练成员,使得模型训练能够不受通信断开的影响继续进行。
在一些实施例中,服务器可以通过丢包率和等待时间判断各个数据持有终端的通信状态。具体地,服务器可以设定可以基于丢包率和等待时间设定预设值,判定各个数据持有终端的通信状态。例如,当某一数据持有终端的丢包率和等待时间都小于预设值,则服务器可判定该数据所有者通信连接。又例如,当某一数据持有终端的丢包率/等待时间大于预设值,则服务器判定该数据所有者不参与本次训练。具体可参见图2步骤231的相关说明。
在一些实施例中,数据拥有者持有的训练数据可以包括私有数据。在一些实施例中,数据拥有者持有的训练数据可以是保险、银行、医疗至少一个领域中的用户属性信息。在一些实施例中,所述用户属性信息包括图像、文本或语音等。
在一些实施例中,联合训练的模型可以根据样本数据的特征做出预测。在一些实施例中,所述模型还可以用于确认用户的身份信息,所述用户的身份信息可以包括但不限于对用户的信用评价。
本说明书一个或多个实施例中的训练数据可以包括与实体相关的数据。在一些实施例中,实体可以理解为可视化的主体,可以包括但不限于用户、商户等。在一些实施例中,所述训练数据可以包括图像数据、文本数据或声音数据。例如,样本数据中的图像数据可以是商户的logo图像、能够反映用户或商户信息的二维码图像等。例如,训练数据中的文本数据可以是用户的性别、年龄、学历、收入等文本数据,或者是商户的交易商品类型、商户进行商品交易的时间以及所述商品的价格区间等等文本数据。这些数据在联合训练的过程中对其他端都是保密的。例如,训练数据的声音数据可以是包含了用户个人信息或用户反馈的相关语音内容,通过解析所述语音内容可得到对应的用户个人信息或用户反馈信息。
步骤320,服务器通过多方安全计算的方式获取累计数据。在一些实施例中,步骤320可以由累计数据获取模块执行。
在一些实施例中,所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数通过多方安全计算的方式确定。在一些实施例中,所述累计数据h0的计算方式具体可参见图2步骤231中h0的计算方式。在一些实施例中,所述多方安全计算的方式包括但不限于秘密分享或和共享的方式。
在一些实施例中,服务器可以设定模型训练的迭代次数T。迭代次数可以理解为模型训练的结束条件,即训练完预先设定的迭代次数即可结束训练。在一些实施例中,服务器可以将预先设定的迭代次数T下发给参与训练的各个成员。在一些实施例中,服务器还可以设定预设批处理参数,即在一次迭代训练过程中批量处理的训练样本的个数,可以理解为图2所述的batch_size。在一些实施例中,服务器可以将所述预设批处理参数传输给各个数据持有终端。在一些实施例中,所述预设的迭代次数以及所述预设批处理参数可以用于对各个数据持有终端的参与本次迭代训练的样本数据进行编号。
在一些实施例中,可以由各个数据持有终端基于接收到的预设迭代次数以及预设批处理参数对自身的参与训练的样本数据进行起始编号和终止编号,即对参与训练的样本数据的第一个开始数据和最后一个终止数据分别编号。具体的编号过程可参见图2步骤230。
在一些实施例中,对于由于通信中断而未参与训练的其他成员,在恢复通信连接后,可以将这些成员训练的起始编号和终止编号与当前训练起始编号和终止编号进行比对。若未参与训练的其他成员的终止编号(或起始编号)小于当前训练的终止编号(或起始编号),则更新未参与训练的数据持有终端终止编号(或起始编号),具体描述详见图2步骤231。
步骤330,基于所述累计数据以及样本标签参与累计梯度值的计算。在一些实施例中,步骤330可以由累计梯度值计算模块执行。
在一些实施例中,当各数据持有终端的训练数据为垂直划分时,样本标签可以由其中一个数据持有终端持有;当各数据持有终端的训练数据为水平划分时,每个数据持有终端均持有一个相同的样本标签。服务器接收到数据持有终端发送的样本标签后,可以基于获取的累计数据来参与累计梯度值的计算。其中,累计梯度值的理解可参见图2中的ΔH。
在一些实施例中,联合训练的模型可以包括线性回归模型;也可以包括逻辑回归模型。在一些实施例中,联合训练的模型还可以是深度神经网络模型。
在一些实施例中,当训练模型是深度神经网络模型时,对于层数较多的神经网络,可以将部分网络结构设置在服务器侧,以提高运算效率。在该场景的一些实施例中,当服务器的实际输出与期望输出存在偏差时,服务器基于所述累计数据以及样本标签计算累计损失值,累计损失值通过输出层,按梯度下降的方式依次修正各层权值,向隐层、输入层逐层反向传播。通过所述反向传播,服务器可以基于所述累计损失值计算更新服务器端输入层的梯度。在一些实施例中,服务器计算得到输入层的梯度,即为累计梯度值。具体计算方式可详见图2步骤232的相关描述。
在一些实施例中,累计损失值可以用来反映训练模型预测值与样本数据真实之间的差距。在一些实施例中,累计损失值可以通过参与运算的方式来反映预设值与真实值的差距。在一些实施例中,累计损失值可以通过损失函数进行计算。其中,不同训练模型的相关运算公式不同,相同训练模型时不同参数寻优算法对应的运算公式也不同。本说明书一个或多个实施例并不对损失值的运算公式即损失函数进行限定。
在一些实施例中,当深度网络模型的神经网络层数较少时,可以将网络节点全部设置在各个数据持有终端进行训练。在一些实施例中,当模型为线性回归模型或逻辑回归模型时,也可以将模型参数的训练更新都放在各个数据持有终端进行。
在该场景的一些实施例中,服务器可以计算累计损失值,然后将所述累计损失值发送给各数据持有终端。各数据持有终端根据所述累计损失值以及自身持有的样本数据计算累计梯度值。其中,各数据持有终端可以通过多方安全计算的方式计算累计梯度值。在该场景的其他实施例中,服务器也可以将所述累计数据发送给各个数据持有终端,由各数据持有终端计算累计损失值,以及对应的累计梯度值。
步骤340,训练成员基于累计梯度值计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新。在一些实施例中,步骤340可以由累计梯度值计算模块执行。
在一些实施例中,训练成员Aj基于前述步骤确定的累计梯度值计算一阶矩和二阶矩,然后完成模型参数WAi的计算更新。一阶矩、二阶矩以及模型参数WAi的具体计算过程可参见步骤235的相关描述。
在一些实施例中,当服务器计算累计梯度值时,服务器需要把对应的累计梯度值发给训练成员。在一些实施例中,在一些实施例中,服务器把累计梯度值发给训练成员Aj时,还可以带有样本验证信息,以确保每次参与训练的样本都具有相同的样本编号。例如,可以将训练成员端的迭代次数tAi作为样本验证信息,若训练成员发送的迭代次数tAi与服务器端的迭代次数T相同,则服务器记录此次累计数据;若训练成员发送的迭代次数tAi与服务器端的迭代次数T不同,则服务器不记录此次累计数据并等待训练成员发送新的累计数据。具体可参见图2中步骤231的相关描述。
在一些实施例中,当累计梯度值是由各训练成员Aj计算所得时,各训练成员可直接基于计算得到的累计梯度值代入对应公式计算一阶矩、二阶矩以及对应的模型参数。
在一些实施例中,对于未参与模型参数更新的其他成员,如果其恢复通信连接,服务器也可以基于其他成员的迭代次数差值将所述累计梯度值下发给对应的其他成员。这些恢复通信连接的其他成员基于获取到所述累计梯度值的单独计算一阶矩,二阶矩,并完成模型超参数的更新。其中,所述一阶矩和所述二阶矩分别用于反映所述累计梯度值的期望和方差。计算的一阶矩和二阶矩具体过程可参见步骤235的相关描述。
应当注意的是,上述有关流程300的描述仅仅是为了示例和说明,而不限定本申请的适用范围。对于本领域技术人员来说,在本申请的指导下可以对流程300进行各种修正和改变。然而,这些修正和改变仍在本申请的范围之内。
本申请实施例可能带来的有益效果包括但不限于:(1)多方数据联合训练,提高数据的利用率,提高预测模型的准确性;(2)服务端更新迭代次数并分发给训练成员,减少了通讯的信息量,降低了通讯时间。(3)各训练成员分别独立累积一阶矩和二阶矩的方法保证了权重更新的独立性。(4)部分训练成员通信中断不影响整体训练进度,且通信恢复后可利用迭代次数更新Adam一阶矩和二阶矩的方法,保证训练模型的准确性。需要说明的是,不同实施例可能产生的有益效果不同,在不同的实施例里,可能产生的有益效果可以是以上任意一种或几种的组合,也可以是其他任何可能获得的有益效果。
上文已对基本概念做了描述,显然,对于本领域技术人员来说,上述详细披露仅仅作为示例,而并不构成对本申请的限定。虽然此处并没有明确说明,本领域技术人员可能会对本申请进行各种修改、改进和修正。该类修改、改进和修正在本申请中被建议,所以该类修改、改进、修正仍属于本申请示范实施例的精神和范围。
同时,本申请使用了特定词语来描述本申请的实施例。如“一个实施例”、“一实施例”、和/或“一些实施例”意指与本申请至少一个实施例相关的某一特征、结构或特点。因此,应强调并注意的是,本说明书中在不同位置两次或多次提及的“一实施例”或“一个实施例”或“一个替代性实施例”并不一定是指同一实施例。此外,本申请的一个或多个实施例中的某些特征、结构或特点可以进行适当的组合。
此外,本领域技术人员可以理解,本申请的各方面可以通过若干具有可专利性的种类或情况进行说明和描述,包括任何新的和有用的工序、机器、产品或物质的组合,或对他们的任何新的和有用的改进。相应地,本申请的各个方面可以完全由硬件执行、可以完全由软件(包括固件、常驻软件、微码等)执行、也可以由硬件和软件组合执行。以上硬件或软件均可被称为“数据块”、“模块”、“引擎”、“单元”、“组件”或“系统”。此外,本申请的各方面可能表现为位于一个或多个计算机可读介质中的计算机产品,该产品包括计算机可读程序编码。
计算机存储介质可能包含一个内含有计算机程序编码的传播数据信号,例如在基带上或作为载波的一部分。该传播信号可能有多种表现形式,包括电磁形式、光形式等,或合适的组合形式。计算机存储介质可以是除计算机可读存储介质之外的任何计算机可读介质,该介质可以通过连接至一个指令执行系统、装置或设备以实现通讯、传播或传输供使用的程序。位于计算机存储介质上的程序编码可以通过任何合适的介质进行传播,包括无线电、电缆、光纤电缆、RF、或类似介质,或任何上述介质的组合。
本申请各部分操作所需的计算机程序编码可以用任意一种或多种程序语言编写,包括面向对象编程语言如Java、Scala、Smalltalk、Eiffel、JADE、Emerald、C++、C#、VB.NET、Python等,常规程序化编程语言如C语言、VisualBasic、Fortran2003、Perl、COBOL2002、PHP、ABAP,动态编程语言如Python、Ruby和Groovy,或其他编程语言等。该程序编码可以完全在用户计算机上运行、或作为独立的软件包在用户计算机上运行、或部分在用户计算机上运行部分在远程计算机运行、或完全在远程计算机或处理设备上运行。在后种情况下,远程计算机可以通过任何网络形式与用户计算机连接,比如局域网(LAN)或广域网(WAN),或连接至外部计算机(例如通过因特网),或在云计算环境中,或作为服务使用如软件即服务(SaaS)。
此外,除非权利要求中明确说明,本申请所述处理元素和序列的顺序、数字字母的使用、或其他名称的使用,并非用于限定本申请流程和方法的顺序。尽管上述披露中通过各种示例讨论了一些目前认为有用的发明实施例,但应当理解的是,该类细节仅起到说明的目的,附加的权利要求并不仅限于披露的实施例,相反,权利要求旨在覆盖所有符合本申请实施例实质和范围的修正和等价组合。例如,虽然以上所描述的系统组件可以通过硬件设备实现,但是也可以只通过软件的解决方案得以实现,如在现有的处理设备或移动设备上安装所描述的系统。
同理,应当注意的是,为了简化本申请披露的表述,从而帮助对一个或多个发明实施例的理解,前文对本申请实施例的描述中,有时会将多种特征归并至一个实施例、附图或对其的描述中。但是,这种披露方法并不意味着本申请对象所需要的特征比权利要求中提及的特征多。实际上,实施例的特征要少于上述披露的单个实施例的全部特征。
一些实施例中使用了描述成分、属性数量的数字,应当理解的是,此类用于实施例描述的数字,在一些示例中使用了修饰词“大约”、“近似”或“大体上”来修饰。除非另外说明,“大约”、“近似”或“大体上”表明所述数字允许有±20%的变化。相应地,在一些实施例中,说明书和权利要求中使用的数值参数均为近似值,该近似值根据个别实施例所需特点可以发生改变。在一些实施例中,数值参数应考虑规定的有效数位并采用一般位数保留的方法。尽管本申请一些实施例中用于确认其范围广度的数值域和参数为近似值,在具体实施例中,此类数值的设定在可行范围内尽可能精确。
针对本申请引用的每个专利、专利申请、专利申请公开物和其他材料,如文章、书籍、说明书、出版物、文档等,特此将其全部内容并入本申请作为参考。与本申请内容不一致或产生冲突的申请历史文件除外,对本申请权利要求最广范围有限制的文件(当前或之后附加于本申请中的)也除外。需要说明的是,如果本申请附属材料中的描述、定义、和/或术语的使用与本申请所述内容有不一致或冲突的地方,以本申请的描述、定义和/或术语的使用为准。
最后,应当理解的是,本申请中所述实施例仅用以说明本申请实施例的原则。其他的变形也可能属于本申请的范围。因此,作为示例而非限制,本申请实施例的替代配置可视为与本申请的教导一致。相应地,本申请的实施例不仅限于本申请明确介绍和描述的实施例。
Claims (17)
1.一种基于Adam优化算法的多方联合训练方法;所述方法包括:
基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员;
服务器通过多方安全计算的方式获取累计数据;所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数确定;
服务器基于所述累计数据以及样本标签参与累计梯度值的计算;所述累计梯度值用于所述训练成员计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新;所述一阶矩和所述二阶矩分别用于反映所述累计梯度值的期望和方差;
其中,所述各个数据终端分别持有自身的训练数据以及与所述训练数据对应的模型参数;所述训练数据包括与实体相关的图像数据、文本数据或声音数据。
2.根据权利要求1所述的方法,所述通信状态包括数据传输的丢包率和/或等待时间;所述基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员包括:
当所述丢包率在预设的丢包率阈值内,和/或所述等待时间在预设的时间阈值内时,将所述完成数据传输的数据持有终端确定为参与模型参数更新的训练成员。
3.根据权利要求1所述的方法,所述方法还包括:
服务器将模型参数更新的预设迭代次数以及预设批处理参数传输给各个数据持有终端;所述预设迭代次数以及所述预设批处理参数用于对各个数据持有终端的训练数据进行编号。
4.根据权利要求3所述的方法,所述方法还包括:
服务器基于所述训练成员当前的迭代次数与所述预设迭代次数之间的迭代次数差值,将与所述迭代次数差值对应个数的累计梯度值传输给所述训练成员。
5.根据权利要求3所述的方法,所述方法还包括:
当通信状态恢复通信连接时,服务器将基于未参与模型参数更新的其他成员的当前迭代次数与所述预设迭代次数之间的迭代次数差值,将与所述迭代次数差值对应个数的累计梯度值传输给所述其他成员;所述累计梯度值用于所述其他成员计算一阶矩和二阶矩。
6.根据权利要求1所述的方法,所述模型包括深度神经网络模型或线性回归模型或逻辑回归模型。
7.根据权利要求6所述的方法,当所述模型包括深度神经网络模型,且所述神经网络模型中部分层的模型参数在服务器端更新时;
所述服务器基于所述累计数据以及样本标签参与累计梯度值的计算包括:
服务器基于所述累计数据以及样本标签计算累计损失值;
服务器基于所述累计损失值,确定所述累计梯度值;并将所述累计梯度值传输给所述训练成员。
8.根据权利要求1所述的方法,所述多方安全计算的方式包括秘密分享。
9.一种基于Adam优化算法的多方联合训练系统;所述系统包括:
训练成员确定模块,用于基于各个数据持有终端的通信状态确定参与模型参数更新的训练成员;
累计数据获取模块,用于通过多方安全计算的方式获取累计数据;所述累计数据由所述训练成员基于自身的训练数据及其对应的模型参数确定;
累计梯度值计算模块,用于基于所述累计数据以及样本标签参与累计梯度值的计算;所述累计梯度值用于所述训练成员计算自身的一阶矩和二阶矩,并基于所述一阶矩和二阶矩完成模型参数的更新;所述一阶矩和所述二阶矩分别用于反映所述累计梯度值的期望和方差;
其中,所述各个数据终端分别持有自身的训练数据以及与所述训练数据对应的模型参数;所述训练数据包括与实体相关的图像数据、文本数据或声音数据。
10.根据权利要求9所述的系统,所述通信状态包括数据传输的丢包率和/或等待时间;
当所述丢包率在预设的丢包率阈值内,和/或所述等待时间在预设的时间阈值内时,所述训练成员确定模块还用于将所述完成数据传输的数据持有终端确定为参与模型参数更新的训练成员。
11.根据权利要求9所述的系统,所述系统还包括:
第一数据传输模块,用于将模型参数更新的预设迭代次数以及预设批处理参数传输给各个数据持有终端;所述预设迭代次数以及所述预设批处理参数用于对各个数据持有终端的训练数据进行编号。
12.根据权利要求11所述的系统,所述系统还包括:
第二数据传输模块,用于基于所述训练成员当前的迭代次数与所述预设迭代次数之间的迭代次数差值,将与所述迭代次数差值对应个数的累计梯度值传输给所述训练成员。
13.根据权利要求11所述的系统,所述系统还包括:
第三数据传输模块,当通信状态恢复通信连接时,用于将基于未参与模型参数更新的其他成员的当前迭代次数与所述预设迭代次数之间的迭代次数差值,将与所述迭代次数差值对应个数的累计梯度值传输给所述其他成员;所述累计梯度值用于所述其他成员计算一阶矩和二阶矩。
14.根据权利要求9所述的系统,所述模型包括深度神经网络模型或线性回归模型或逻辑回归模型。
15.根据权利要求14所述的系统,当所述模型包括深度神经网络模型,且所述神经网络模型中部分层的模型参数在服务器端更新时;所述累计梯度值计算模块还用于:
基于所述累计数据以及样本标签计算累计损失值;基于所述累计损失值,确定所述累计梯度值;并将所述累计梯度值传输给所述训练成员。
16.根据权利要求9所述的系统,所述多方安全计算的方式包括秘密分享。
17.一种基于Adam优化算法的多方联合训练装置,所述装置包括处理器以及存储器;所述存储器用于存储指令,所述处理器用于执行所述指令,以实现如权利要求1至8中任一项所述基于Adam优化算法的多方联合训练方法对应的操作。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010248683.8A CN111460528B (zh) | 2020-04-01 | 2020-04-01 | 一种基于Adam优化算法的多方联合训练方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010248683.8A CN111460528B (zh) | 2020-04-01 | 2020-04-01 | 一种基于Adam优化算法的多方联合训练方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111460528A true CN111460528A (zh) | 2020-07-28 |
CN111460528B CN111460528B (zh) | 2022-06-14 |
Family
ID=71678489
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010248683.8A Active CN111460528B (zh) | 2020-04-01 | 2020-04-01 | 一种基于Adam优化算法的多方联合训练方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111460528B (zh) |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111898740A (zh) * | 2020-07-31 | 2020-11-06 | 北京达佳互联信息技术有限公司 | 预测模型的模型参数更新方法及装置 |
CN112100295A (zh) * | 2020-10-12 | 2020-12-18 | 平安科技(深圳)有限公司 | 基于联邦学习的用户数据分类方法、装置、设备及介质 |
CN112149158A (zh) * | 2020-08-19 | 2020-12-29 | 成都飞机工业(集团)有限责任公司 | 一种基于同态加密技术的3d打印多数据库共享优化算法 |
CN112288100A (zh) * | 2020-12-29 | 2021-01-29 | 支付宝(杭州)信息技术有限公司 | 一种基于联邦学习进行模型参数更新的方法、系统及装置 |
CN112396191A (zh) * | 2020-12-29 | 2021-02-23 | 支付宝(杭州)信息技术有限公司 | 一种基于联邦学习进行模型参数更新的方法、系统及装置 |
CN112561069A (zh) * | 2020-12-23 | 2021-03-26 | 北京百度网讯科技有限公司 | 模型处理方法、装置、设备、存储介质及产品 |
CN112800466A (zh) * | 2021-02-10 | 2021-05-14 | 支付宝(杭州)信息技术有限公司 | 基于隐私保护的数据处理方法、装置和服务器 |
CN113268727A (zh) * | 2021-07-19 | 2021-08-17 | 天聚地合(苏州)数据股份有限公司 | 联合训练模型方法、装置及计算机可读存储介质 |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107315570A (zh) * | 2016-04-27 | 2017-11-03 | 北京中科寒武纪科技有限公司 | 一种用于执行Adam梯度下降训练算法的装置及方法 |
CN109308418A (zh) * | 2017-07-28 | 2019-02-05 | 阿里巴巴集团控股有限公司 | 一种基于共享数据的模型训练方法及装置 |
CN109754060A (zh) * | 2017-11-06 | 2019-05-14 | 阿里巴巴集团控股有限公司 | 一种神经网络机器学习模型的训练方法及装置 |
CN110135573A (zh) * | 2018-02-02 | 2019-08-16 | 阿里巴巴集团控股有限公司 | 一种深度学习模型的训练方法、计算设备以及系统 |
CN110263908A (zh) * | 2019-06-20 | 2019-09-20 | 深圳前海微众银行股份有限公司 | 联邦学习模型训练方法、设备、系统及存储介质 |
CN110276210A (zh) * | 2019-06-12 | 2019-09-24 | 深圳前海微众银行股份有限公司 | 基于联邦学习的模型参数的确定方法及装置 |
CN110288094A (zh) * | 2019-06-10 | 2019-09-27 | 深圳前海微众银行股份有限公司 | 基于联邦学习的模型参数训练方法及装置 |
CN110442457A (zh) * | 2019-08-12 | 2019-11-12 | 北京大学深圳研究生院 | 基于联邦学习的模型训练方法、装置及服务器 |
CN110457984A (zh) * | 2019-05-21 | 2019-11-15 | 电子科技大学 | 监控场景下基于ResNet-50的行人属性识别方法 |
CN110543911A (zh) * | 2019-08-31 | 2019-12-06 | 华南理工大学 | 一种结合分类任务的弱监督目标分割方法 |
US20200019842A1 (en) * | 2019-07-05 | 2020-01-16 | Lg Electronics Inc. | System, method and apparatus for machine learning |
CN110929886A (zh) * | 2019-12-06 | 2020-03-27 | 支付宝(杭州)信息技术有限公司 | 模型训练、预测方法及其系统 |
-
2020
- 2020-04-01 CN CN202010248683.8A patent/CN111460528B/zh active Active
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107315570A (zh) * | 2016-04-27 | 2017-11-03 | 北京中科寒武纪科技有限公司 | 一种用于执行Adam梯度下降训练算法的装置及方法 |
CN109308418A (zh) * | 2017-07-28 | 2019-02-05 | 阿里巴巴集团控股有限公司 | 一种基于共享数据的模型训练方法及装置 |
CN109754060A (zh) * | 2017-11-06 | 2019-05-14 | 阿里巴巴集团控股有限公司 | 一种神经网络机器学习模型的训练方法及装置 |
CN110135573A (zh) * | 2018-02-02 | 2019-08-16 | 阿里巴巴集团控股有限公司 | 一种深度学习模型的训练方法、计算设备以及系统 |
CN110457984A (zh) * | 2019-05-21 | 2019-11-15 | 电子科技大学 | 监控场景下基于ResNet-50的行人属性识别方法 |
CN110288094A (zh) * | 2019-06-10 | 2019-09-27 | 深圳前海微众银行股份有限公司 | 基于联邦学习的模型参数训练方法及装置 |
CN110276210A (zh) * | 2019-06-12 | 2019-09-24 | 深圳前海微众银行股份有限公司 | 基于联邦学习的模型参数的确定方法及装置 |
CN110263908A (zh) * | 2019-06-20 | 2019-09-20 | 深圳前海微众银行股份有限公司 | 联邦学习模型训练方法、设备、系统及存储介质 |
US20200019842A1 (en) * | 2019-07-05 | 2020-01-16 | Lg Electronics Inc. | System, method and apparatus for machine learning |
CN110442457A (zh) * | 2019-08-12 | 2019-11-12 | 北京大学深圳研究生院 | 基于联邦学习的模型训练方法、装置及服务器 |
CN110543911A (zh) * | 2019-08-31 | 2019-12-06 | 华南理工大学 | 一种结合分类任务的弱监督目标分割方法 |
CN110929886A (zh) * | 2019-12-06 | 2020-03-27 | 支付宝(杭州)信息技术有限公司 | 模型训练、预测方法及其系统 |
Non-Patent Citations (2)
Title |
---|
TRISHUL CHILIMBI: "Project Adam: building an efficient and scalable deep learning training system", 《OSDI"14: PROCEEDINGS OF THE 11TH USENIX CONFERENCE ON OPERATING SYSTEMS DESIGN AND IMPLEMENTATION》 * |
王欢等: "联合多任务学习的人脸超分辨率重建", 《中国图象图形学报》 * |
Cited By (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111898740B (zh) * | 2020-07-31 | 2021-07-20 | 北京达佳互联信息技术有限公司 | 预测模型的模型参数更新方法及装置 |
CN111898740A (zh) * | 2020-07-31 | 2020-11-06 | 北京达佳互联信息技术有限公司 | 预测模型的模型参数更新方法及装置 |
CN112149158A (zh) * | 2020-08-19 | 2020-12-29 | 成都飞机工业(集团)有限责任公司 | 一种基于同态加密技术的3d打印多数据库共享优化算法 |
CN112100295A (zh) * | 2020-10-12 | 2020-12-18 | 平安科技(深圳)有限公司 | 基于联邦学习的用户数据分类方法、装置、设备及介质 |
WO2021179720A1 (zh) * | 2020-10-12 | 2021-09-16 | 平安科技(深圳)有限公司 | 基于联邦学习的用户数据分类方法、装置、设备及介质 |
CN112561069B (zh) * | 2020-12-23 | 2021-09-21 | 北京百度网讯科技有限公司 | 模型处理方法、装置、设备及存储介质 |
CN112561069A (zh) * | 2020-12-23 | 2021-03-26 | 北京百度网讯科技有限公司 | 模型处理方法、装置、设备、存储介质及产品 |
CN112396191B (zh) * | 2020-12-29 | 2021-05-11 | 支付宝(杭州)信息技术有限公司 | 一种基于联邦学习进行模型参数更新的方法、系统及装置 |
CN112288100B (zh) * | 2020-12-29 | 2021-08-03 | 支付宝(杭州)信息技术有限公司 | 一种基于联邦学习进行模型参数更新的方法、系统及装置 |
CN112396191A (zh) * | 2020-12-29 | 2021-02-23 | 支付宝(杭州)信息技术有限公司 | 一种基于联邦学习进行模型参数更新的方法、系统及装置 |
CN112288100A (zh) * | 2020-12-29 | 2021-01-29 | 支付宝(杭州)信息技术有限公司 | 一种基于联邦学习进行模型参数更新的方法、系统及装置 |
CN112800466A (zh) * | 2021-02-10 | 2021-05-14 | 支付宝(杭州)信息技术有限公司 | 基于隐私保护的数据处理方法、装置和服务器 |
CN112800466B (zh) * | 2021-02-10 | 2022-04-22 | 支付宝(杭州)信息技术有限公司 | 基于隐私保护的数据处理方法、装置和服务器 |
CN113268727A (zh) * | 2021-07-19 | 2021-08-17 | 天聚地合(苏州)数据股份有限公司 | 联合训练模型方法、装置及计算机可读存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN111460528B (zh) | 2022-06-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111460528B (zh) | 一种基于Adam优化算法的多方联合训练方法及系统 | |
CN110929886B (zh) | 模型训练、预测方法及其系统 | |
CN111931950B (zh) | 一种基于联邦学习进行模型参数更新的方法及系统 | |
US11829874B2 (en) | Neural architecture search | |
Liu et al. | A communication efficient collaborative learning framework for distributed features | |
US20200334495A1 (en) | Systems and Methods for Determining Graph Similarity | |
Xu et al. | Personalized course sequence recommendations | |
CN112085159B (zh) | 一种用户标签数据预测系统、方法、装置及电子设备 | |
US11636314B2 (en) | Training neural networks using a clustering loss | |
CN111178547B (zh) | 一种基于隐私数据进行模型训练的方法及系统 | |
CN113011587B (zh) | 一种隐私保护的模型训练方法和系统 | |
US20220391778A1 (en) | Online Federated Learning of Embeddings | |
CN111931876B (zh) | 一种用于分布式模型训练的目标数据方筛选方法及系统 | |
CN113191484A (zh) | 基于深度强化学习的联邦学习客户端智能选取方法及系统 | |
US20200372305A1 (en) | Systems and Methods for Learning Effective Loss Functions Efficiently | |
CN114611720B (zh) | 联邦学习模型训练方法、电子设备及存储介质 | |
CN112799708A (zh) | 联合更新业务模型的方法及系统 | |
WO2021189926A1 (zh) | 业务模型训练方法、装置、系统及电子设备 | |
Cooper | Multidisciplinary flux and multiple research traditions within cognitive science | |
CN110462638A (zh) | 使用后验锐化训练神经网络 | |
US11569278B1 (en) | Systems and methods for callable options values determination using deep machine learning | |
US20190318422A1 (en) | Deep learning approach for assessing credit risk | |
CN110674181B (zh) | 信息推荐方法、装置、电子设备及计算机可读存储介质 | |
WO2022222816A1 (zh) | 用于隐私保护的模型的训练方法、系统及装置 | |
CN112948885A (zh) | 实现隐私保护的多方协同更新模型的方法、装置及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
REG | Reference to a national code |
Ref country code: HK Ref legal event code: DE Ref document number: 40034100 Country of ref document: HK |
|
GR01 | Patent grant | ||
GR01 | Patent grant |