CN114492841A - 一种模型梯度更新方法及装置 - Google Patents

一种模型梯度更新方法及装置 Download PDF

Info

Publication number
CN114492841A
CN114492841A CN202210107380.3A CN202210107380A CN114492841A CN 114492841 A CN114492841 A CN 114492841A CN 202210107380 A CN202210107380 A CN 202210107380A CN 114492841 A CN114492841 A CN 114492841A
Authority
CN
China
Prior art keywords
gradient
nodes
probability
node
value
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210107380.3A
Other languages
English (en)
Inventor
程栋
程新
周雍恺
高鹏飞
姜铁城
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
China Unionpay Co Ltd
Original Assignee
China Unionpay Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by China Unionpay Co Ltd filed Critical China Unionpay Co Ltd
Priority to CN202210107380.3A priority Critical patent/CN114492841A/zh
Publication of CN114492841A publication Critical patent/CN114492841A/zh
Priority to PCT/CN2022/112615 priority patent/WO2023142439A1/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F8/00Arrangements for software engineering
    • G06F8/60Software deployment
    • G06F8/65Updates

Landscapes

  • Engineering & Computer Science (AREA)
  • Software Systems (AREA)
  • Theoretical Computer Science (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Computer Security & Cryptography (AREA)
  • Information Transfer Between Computers (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请提供一种模型梯度更新方法及装置,用以提高模型训练的准确性。中心服务器重复执行梯度更新过程,直至满足停止条件;其中,一次梯度更新过程包括:接收多个节点分别发送的第一梯度,第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,本次梯度更新过程中的每个节点的概率为Actor‑Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;将第二梯度分别发送给多个节点,以使多个节点采用第二梯度对各自的待训练的模型的权重进行更新。考虑到每个节点的概率,可以优化节点参与度,使确定出的模型更优。

Description

一种模型梯度更新方法及装置
技术领域
本申请涉及模型训练技术领域,特别涉及一种模型梯度更新方法及装置。
背景技术
横向联邦学习也称为按样本划分的联邦学习,可以应用于联邦学习的各个参与方的数据集有相同的特征和不同的样本的场景。
通常假设一个横向联邦学习系统的参与方都是诚实的,需要防范的对象是一个诚实但好奇的中心服务器。即通常假设只有中心服务器才能使得数据参与方的隐私安全受到威胁。在横向联邦学习系统中,具有同样数据特征的多个参与方在中心服务器的帮助下,协作地训练一个模型。主要包括以下步骤:各参与方在本地计算模型梯度,并梯度(梯度需要加密)发送给中心服务器。中心服务器对多个梯度进行聚合。中心服务器将聚合后的梯度(梯度也需要加密)发送给各参与方。各参与方使用接收到的梯度更新各自的模型参数。
上述步骤持续迭代进行,直到损失函数收敛或者达到允许的迭代次数的上限或允许的训练时间,这种架构独立于特定的机器学习算法(如逻辑回归和深度神经网络),并且所有参与方将会共享最终的模型参数。
目前,横向联邦学习场景中,中心服务器对梯度进行平均聚合,考虑到不同参与方的性能不同,采用梯度平均的方式训练出的模型结果不佳。
发明内容
本申请提供一种模型梯度更新的方法及装置,用以提高模型训练的准确性。
为达到上述目的,本申请实施例公开了一种模型梯度更新方法,应用于中心服务器,包括:
中心服务器重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
一种可选的示例中,所述Actor-Critic网络包括Actor网络、至少一个Critic网络、及奖励函数;
所述奖励函数用于基于上一次梯度更新过程中确定的所述多个节点的概率,确定奖励值,并将奖励值传输至所述至少一个Critic网络;
所述至少一个Critic网络用于确定目标Q值,并将所述目标Q值传输至所述Actor网络;
所述Actor网络用于基于所述目标Q值确定本次梯度更新过程中的每个节点的概率。
一种可选的示例中,所述目标Q值为多个Critic网络确定的Q值中的最小Q值。
一种可选的示例中,奖励函数满足:
Figure BDA0003494388330000021
其中,A为第一准确率,B为第二准确率,g大于或等于1,其中,第一准确率为所述中心服务器与所述多个节点基于联邦平均学习算法得到的训练完成的模型的准确率;第二准确率为所述多个节点分别发送的第三准确率的平均值,所述第三准确为与所述第一梯度在所述节点采用样本数据对所述节点中的待训练的模型进行同一次模型训练中得到的。
一种可选的示例中,当
Figure BDA0003494388330000022
大于1时,g大于1;当
Figure BDA0003494388330000023
小于或等于1时,g为1。
一种可选的示例中,所述Actor-Critic网络包括3个Critic网络,针对任一Critic网络,在本次梯度更新过程中确定的Q值基于Q值梯度和上一次梯度更新过程中确定的Q值确定,所述Q值梯度基于第一参数确定,所述第一参数满足以下公式:
Figure BDA0003494388330000031
其中,
Figure BDA0003494388330000032
其中,J为所述第一参数;t为本次梯度更新的次数;k>0,l>0,k+l=1;θ1,θ2,θ3分别表示3个Critic网络,θi为θ1,θ2,θ3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;st为第t次梯度更新过程中的状态;at为第t次梯度更新过程中所述多个节点的概率;
Figure BDA0003494388330000033
为第t次梯度更新过程中θi对应的Critic网络在st,at情况下确定的Q值;
Figure BDA0003494388330000034
为第t次梯度更新过程中θ3对应的Critic网络在st,at情况下输出的Q值;r(st,at)为第t次梯度更新过程中在st,at情况下的奖励值;γ大于0;πt(at|st)为在st下做出at的概率;q为熵的指数,lnq为熵,αt不为0。
一种可选的示例中,在本次梯度更新过程中采用的α基于α梯度和上一次梯度更新过程中采用的α确定,所述α梯度满足以下公式:
Figure BDA0003494388330000035
其中,J(α)为α梯度,αt-1为上一次梯度更新采用的α,H为理想的最小期望熵。
一种可选的示例中,所述k、l基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定。
一种可选的示例中,所述Actor网络在本次梯度更新过程中输出的节点的概率基于概率梯度和上一次梯度更新过程中输出的节点的概率确定,所述概率梯度满足以下公式:
Figure BDA0003494388330000041
其中,J(πφ)为概率梯度,t为本次梯度更新的次数;θ1,θ2,θ3分别表示3个Critic网络,θi为θ1,θ2,θ3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;st为第t次梯度更新过程中的状态;at为第t次梯度更新过程中所述多个节点的概率;Qθi(st,at)为第t次梯度更新过程中θi对应的Critic网络在st,at情况下确定的Q值;πt(at|st)为在st下做出at的概率;q为熵的指数,lnq为熵,αt为本次梯度更新采用的α,αt不为0。
一种可选的示例中,所述中心服务器和所述多个节点基于联邦学习架构进行梯度更新。
本申请实施例提供了一种模型梯度更新装置,包括:
重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
接收模块,用于接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;
处理模块,用于基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;
发送模块,用于将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
本申请实施例提供了一种模型梯度更新装置,包括处理器和存储器;
所述存储器,用于存储计算机程序或指令;
所述处理器,用于执行所述存储器中的部分或者全部计算机程序或指令,当所述部分或者全部计算机程序或指令被执行时,用于实现上述任一项所述的模型梯度更新方法。
本申请实施例提供了一种计算机可读存储介质,用于存储计算机程序,所述计算机程序包括用于实现任一项所述的模型梯度更新方法的指令。
本申请考虑到每个节点的概率,可以优化节点参与度,使确定出的模型更优。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请提供的一种模型梯度更新流程示意图;
图2为本申请提供的一种模型梯度更新系统架构图;
图3为本申请提供的一种模型梯度更新装置结构图;
图4为本申请提供的一种模型梯度更新装置结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
中心服务器重复执行梯度更新过程,直至满足停止条件。停止条件例如为损失函数收敛、或者达到允许的迭代次数的上限、或达到允许的训练时间。
接下来如图1所示,对任一次所述梯度更新过程进行介绍:
步骤101:多个节点(例如节点1、节点2、……节点m)分别向中心服务器发送第一梯度。相应的,中心服务器接收所述多个节点分别发送的第一梯度。
所述第一梯度为每个节点采用样本数据对所述节点中的待训练的模型进行一次或多次训练得到。将节点向中心服务器发送的梯度称为第一梯度,多个节点发送的第一梯度可能是相同的,也可能是不同的。
节点可以是汽车终端,样本数据可以是在自动驾驶中产生的数据。汽车行驶中生成不同的行驶数据,行驶数据分散在各个节点(汽车终端),并存在数据质量不均衡,节点性能不均衡的特点。模型可以是需要进行自动驾驶、用户习惯相关的模型。
步骤102:Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率,确定本次梯度更新过程中的每个节点的概率。
可以理解的是,如果本次梯度更新为第一次梯度更新,则上一次梯度更新过程中的每个节点的概率为初始设定的每个节点的概率。
多个节点的概率之和可以是1。
Actor-Critic网络可以每循环一次就输出多个节点的概率,也可以是循环多次才输出节点的概率。
Actor-Critic网络可以是在中心服务器上,也可以不在所述中心服务器上。
步骤103:中心服务器基于多个第一梯度和多个概率得到第二梯度。
在基于多个第一梯度和多个概率得到第二梯度时,可以将多个第一梯度和多个概率用加权平均的方式确定第二梯度,例如节点1、节点2、节点3分别发送的第一梯度为p1、p2、p3,三个节点的概率为0.2、0.4、0.4,则第二梯度为0.2p1+0.4p2+0.4p3的取值。该过程也可以称为基于深度强化学习数据融合算法,多个第一梯度和多个概率得到第二梯度。
步骤104:中心服务器将所述第二梯度分别发送给所述多个节点(例如节点1、节点2、……节点m),相应的,多个节点接收来自中心服务器的第二梯度。
将中心服务器向节点发送的梯度称为第二梯度,向多个节点发送的第二梯度是相同的。
步骤105:多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
如果满足停止条件,则更新后的模型即为训练完成的模型。
如果未满足停止条件,则每个节点可以采用样本数据对所述节点中的待训练的模型进行一次或多次训练,得到新的第一梯度,继续重复执行步骤101-步骤105。
考虑到每个节点的概率,可以优化节点参与度,使确定出的模型更优。
可选的,所述中心服务器和所述多个节点基于联邦学习架构进行梯度更新。各个节点的样本数据私有,与其它节点和中心服务器不共享,并且节点和中心服务器在传输第一梯度和第二梯度时,第一梯度和第二梯度是加密的。
如图2所示,介绍一种梯度更新系统架构图,包括多个节点、中心服务器、和Actor-Critic网络。Actor-Critic网络可以位于中心服务器上,也可以不位于中心服务器上。Actor-Critic从名字上看包括两部分,演员(Actor)和评价者(Critic)。其中Actor负责生成动作(Action)并和环境交互。而Critic负责评估Actor的表现,并指导Actor下一阶段的动作。
所述Actor-Critic网络包括Actor网络、至少一个Critic网络,及奖励函数;
所述奖励函数用于基于上一次梯度更新过程中确定的所述多个节点的概率,确定奖励值,并将奖励值传输至所述至少一个Critic网络;
所述至少一个Critic网络用于确定目标Q值,并将目标Q值传输至所述Actor网络。每个Critic网络确定出一个Q值,如果有一个Critic网络,则该Critic网络确定出的Q值即为目标Q值,如果有多个Critic网络,则可以在多个Q值选择出一个Q值作为目标Q值。例如目标Q值为多个Critic网络确定的Q值中的最小Q值。当有多个Critic网络时,相当于设置了多个评价者,评估Actor的表现更加准确,使Actor做出的动作更加准确,进而得出的多个节点的概率符合多个节点的性能情况。
所述Actor网络基于所述目标Q值确定本次梯度更新过程中的每个节点的概率。并将所述多个节点的概率传输至所述奖励函数,多次循环,直至停止。
本申请可以采用现有的Actor-Critic网络确定概率,也可以对现有的Actor-Critic网络进行改进,例如设置多个Critic网络,例如对Critic网络中涉及的算法进行改进,例如对Actor网络中涉及的算法进行改进,例如对奖励函数进行改进,可以理解的是,改进只是涉及到具体的细节,改进后的Actor-Critic网络与现有的Actor-Critic网络的运行机制是类似的。
结合图2介绍的系统,对本申请的梯度更新过程进行详细介绍。
可以先对Actor-Critic网络中涉及的参数进行初始化,包括但不限于:对Critic网络中涉及的参数,对Actor网络中涉及的参数,对奖励函数中涉及的参数进行初始化。
中心服务器与多个节点基于联邦平均学习算法得到的训练完成的模型,并确定所述训练完成的模型的第一准确率A。
节点1、节点2……、节点m基于当前保存的模型(也可以称为待训练的模型,例如可以是还执行步骤101-步骤105的梯度更新过程的模型,也可以是已经执行过一次或多次步骤101-步骤105的梯度更新过程的模型)进行一次或多次训练得到第一梯度及第三准确率B’,各个节点向中心服务器发送第一梯度及第三准确率B’,第三准确率与第一梯度在节点进行同一次模型训练中得到的。
中心服务器对多个第三准确率B’计算平均值,得到第二准确率B。
奖励函数基于第一准确率A和第二准确率B确定奖励值r,第一准确率为所述中心服务器与所述多个节点基于联邦学习得到的训练完成的模型的准确率;第二准确率为所述多个节点分别发送的第三准确率的平均值,所述第三准确为与所述第一梯度在所述节点采用样本数据对所述节点中的待训练的模型进行同一次模型训练中得到的。
例如奖励函数表示为:
Figure BDA0003494388330000081
结果B/结果A越高,奖励值r越高。可以理解的是,在任一次梯度更新过程中,A的取值都是相同的,B的取值可能相同,也可能不同。g大于或等于1。可选的,本申请设置2个奖励函数,分别为:当
Figure BDA0003494388330000082
大于1时,g取值为大于1的常数,可以进行强引导,以更快的完成梯度训练;当
Figure BDA0003494388330000091
小于或等于1时,g设置为1。
所述Critic网络本次梯度更新过程中确定的Q值(即更新后的Q值)基于Q值梯度和上一次梯度更新过程中确定的Q值(即更新前的Q值)确定,例如,更新后的Q值=更新前的Q值+Q值梯度。可以理解的是,如果本次梯度更新为第一次梯度更新,则上一次输出的Q值为初始设定的Q值。
一种示例中,Q值梯度基于第一算法和第二算法确定,其中,第一算法具有在训练中偏向于一个特定的动作的特性,第二算法具有在训练中均衡选择各个工作的特性。例如,第一算法可以是深度确定性策略梯度算法(Deep Deterministic Policy Gradient,DDPG)算法,第二算法可以是SAC算法,SAC算法可以是具有自动熵调节功能的SAC(Soft actorcritic with automatic entropy adjustment,SAC-AEA)强化学习算法。DDPG受制于算法本身的更新策略,训练后期会偏向于一个特定的动作(action),这不利于调度多个节点的概率,对实现整体性模型融合是不利的,会导致最后训练出来的模型与某个节点的数据高度相关,其他节点的模型数据对模型结果贡献度变低,也就是对于多方数据的利用效率大大降低,甚至会导致模型训练结果不佳,或出现过拟合等问题。SAC-AEA强化学习算法,本身可以较为均衡选择各个动作(action),然而对于实际的联邦学习框架,各个节点的数据质量、对模型的贡献度、本地算力(例如本地设备的计算效率)等都不同(在本申请中可以将他们表述为优势方和非优势方)的情况下,均衡地融合他们的数据显然不利于模型训练结果的提升的,或出现欠拟合,不能完全表征完整的数据特征。本申请中所述Critic网络确定的Q值基于DDPG算法和SAC算法更新。融合DDPG算法和SAC算法,可以结合两个算法的优势,使训练出的模型融合多个节点的性能,模型较优。本申请可以对DDPG算法设置第一权重,对SAC算法设置第二权重,基于DDPG算法及第一权重、SAC算法及第二权重确定Q值梯度,所述第一权重和所述第二权重基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定。
一种示例中,在Critic网络中,提出了基于复合自适应可调权重的Q值更新算法;Q值梯度基于第一参数J确定,例如Q值梯度为第一参数J与步长的乘积,步长不为0。以3个Critic网络为例,第一参数J满足以下公式:
Figure BDA0003494388330000101
Figure BDA0003494388330000102
或者,
Figure BDA0003494388330000103
其中,
Figure BDA0003494388330000104
该示例以Actor-Critic网络可以每循环一次就输出多个节点的概率为例进行介绍,t为梯度更新的次数,例如,第t次梯度更新过程,t为大于或等于1的整数。
k>0,l>0,k+l=1;k和l可以是固定值,可以是人为设置的数值,还可以是基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定;准确率的方差可以在一定程度上表示节点的性能(例如算力)差异等。方差越大,节点的性能差异越大,反之,方差越小,多个节点的性能差异越小。
θ1,θ2,θ3分别表示3个Critic网络,θi为θ1,θ2,θ3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络。可以理解的是,在本次梯度更新过程中,3个Critic网络最新确定出的Q值是指在上一次梯度更新(如果本次梯度更新为第t次,则上一次为第t-1次)过程中确定出的Q值;
st为第t次梯度更新过程中的状态;st可以是各个节点的准确率或者准确率平均值,可以是各个节点的梯度或者梯度平均值,可以是各个节点的方差等;
at为第t次梯度更新过程中所述多个节点的概率;at也可以称为动作;
Figure BDA0003494388330000105
为第t次梯度更新过程中θi对应的Critic网络在st,at情况下确定的Q值;
Figure BDA0003494388330000111
为第t次梯度更新过程中θ3对应的Critic网络在st,at情况下输出的Q值;
r(st,at)为第t次梯度更新过程中在st,at情况下的奖励值;
γ为衰减因子,γ大于0;
πt(at|st)为条件概率,πt(at|st)为在st下做出at的概率;
q为熵(例如Tasslis熵)的指数,q为大于或等于1的整数,可以是1或2或3,lnq为熵,是一个曲线。当q不同时,lnq为一个曲线族;
E为数字期望,E对[*]内的数据求期望,自变量为st和at,上述公式中的[]内的内容是对st和at的隐式表达;
D为记忆库(也可以称为经验回放池、或缓存空间),(s,a)~D是指记忆库中的s,a;假设记忆库D中可以存储M个s及M个a,记忆库D为循环覆盖,当则Actor-Critic网络可以先自循环M次,得到M个s及M个a,在第M+1次循环时Actor-Critic网络才输出节点的概率,前面的M次循环中Actor-Critic网络不输出节点的概率,或者说前面的M次循环中Actor-Critic网络输出的节点的概率可以忽略不计。
a~πφ表示πφ基于a确定,例如多个a组成πφ曲线(或集合)。
αt为第t次梯度更新过程中采用的α,α可以是个固定值,不为0即可,α也可以是变量(即自适应参数),本次梯度更新采用的(更新后的)α基于α梯度和上一次梯度更新采用的(更新前的)α确定。例如,本次梯度更新采用的(更新后的)α=α梯度和上一次梯度更新采用的(更新前的)α。可以理解的是,如果本次梯度更新为第一次梯度更新,则上一次采用的α为初始设定的α。α梯度满足以下公式:
Figure BDA0003494388330000112
其中,J(α)为α梯度,αt-1为上一次梯度更新采用的α,H为理想的最小期望熵,熵例如为Tasslis熵。
通过复合自适应可调权重的Q值更新方法,可以自适应调整梯度权重参数,可以根据实际场景,只需调整权重参数,对于进行特定动作action选取,或均衡动作选取,从而可以根据具体情况调度各参与方的模型信息。
所述Actor网络本次输出的节点的概率(即更新后的节点的概率)为上一次输出的节点的概率(即更新前的节点的概率)与概率梯度的和值。
在Actor网络中融入Tasslis熵概念和自适应参数,概率梯度基于以下公式确定:
Figure BDA0003494388330000121
其中,J(πφ)为概率梯度,αt为本次(第t次)梯度更新过程得到的α;θi为θ1,θ2,θ3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络。可以理解的是,在本次梯度更新过程中,3个Critic网络最新确定出的Q值是指在本次梯度更新(例如第t次梯度更新)过程中确定出的Q值。
在联邦框架中的服务器端,调整了不同节点的数据融合算法,设计了将深度强化学习数据融合模型与联邦平均算法相结合的优化融合策略。可以调整不同节点的参与度和训练时的数据利用程度。
前文介绍了本申请实施例的方法,下文中将介绍本申请实施例中的装置。方法、装置是基于同一技术构思的,由于方法、装置解决问题的原理相似,因此装置与方法的实施可以相互参见,重复之处不再赘述。
本申请实施例可以根据上述方法示例,对装置进行功能模块的划分,例如,可以对应各个功能划分为各个功能模块,也可以将两个或两个以上的功能集成在一个模块中。这些模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。需要说明的是,本申请实施例中对模块的划分是示意性的,仅仅为一种逻辑功能划分,具体实现时可以有另外的划分方式。
基于与上述方法的同一技术构思,参见图3,提供了一种模型梯度更新装置,包括:
接收模块301,用于接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;
处理模块302,用于基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;
发送模块303,用于将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
以上过程为一次所述梯度更新过程,重复执行梯度更新过程,直至满足停止条件。
基于与上述方法的同一技术构思,参见图4,提供了一种模型梯度更新装置,包括处理器401和存储器402,可选的,还包括收发器403;
所述存储器402,用于存储计算机程序或指令;
所述处理器401,用于执行所述存储器中的部分或者全部计算机程序或指令,当所述部分或者全部计算机程序或指令被执行时,用于实现上述任一项所述的模型梯度更新方法。例如收发器403执行接收和发送动作,处理器401执行处接收和发送动作外的其它动作。
本申请实施例提供了一种计算机可读存储介质,用于存储计算机程序,所述计算机程序包括用于实现任一项所述的模型梯度更新方法的指令。
本申请实施例还提供了一种计算机程序产品,包括:计算机程序代码,当所述计算机程序代码在计算机上运行时,使得计算机可以执行上述提供的模型梯度更新的方法。
本申请实施例还提供了一种通信的系统,所述通信系统包括:执行上述模型梯度更新的方法的节点和中心服务器。
另外,本申请实施例中提及的处理器可以是中央处理器(central processingunit,CPU),基带处理器,基带处理器和CPU可以集成在一起,或者分开,还可以是网络处理器(network processor,NP)或者CPU和NP的组合。处理器还可以进一步包括硬件芯片或其他通用处理器。上述硬件芯片可以是专用集成电路(application-specific integratedcircuit,ASIC),可编程逻辑器件(programmable logic device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(complex programmable logic device,CPLD),现场可编程逻辑门阵列(field-programmable gate array,FPGA),通用阵列逻辑(generic array logic,GAL)及其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等或其任意组合。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
本申请实施例中提及的存储器可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(Read-OnlyMemory,ROM)、可编程只读存储器(Programmable ROM,PROM)、可擦除可编程只读存储器(Erasable PROM,EPROM)、电可擦除可编程只读存储器(Electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(Random Access Memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(Synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double Data RateSDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(Synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(DirectRambus RAM,DR RAM)。应注意,本申请描述的存储器旨在包括但不限于这些和任意其它适合类型的存储器。
本申请实施例中提及的收发器中可以包括单独的发送器,和/或,单独的接收器,也可以是发送器和接收器集成一体。收发器可以在相应的处理器的指示下工作。可选的,发送器可以对应物理设备中发射机,接收器可以对应物理设备中的接收机。
本领域普通技术人员可以意识到,结合本文中所公开的实施例中描述的各方法步骤和单元,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各实施例的步骤及组成。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。本领域普通技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统、装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口、装置或单元的间接耦合或通信连接,也可以是电的,机械的或其它的形式连接。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本申请实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分,或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(read-only memory,ROM)、随机存取存储器(random access memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
本申请中的“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。本申请中所涉及的多个,是指两个或两个以上。另外,需要理解的是,在本申请的描述中,“第一”、“第二”等词汇,仅用于区分描述的目的,而不能理解为指示或暗示相对重要性,也不能理解为指示或暗示顺序。
尽管已描述了本申请的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本申请范围的所有变更和修改。
显然,本领域的技术人员可以对本申请实施例进行各种改动和变型而不脱离本申请实施例的精神和范围。这样,倘若本申请实施例的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包括这些改动和变型在内。

Claims (13)

1.一种模型梯度更新方法,其特征在于,应用于中心服务器,包括:
中心服务器重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
2.如权利要求1所述的方法,其特征在于,所述Actor-Critic网络包括Actor网络、至少一个Critic网络、及奖励函数;
所述奖励函数用于基于上一次梯度更新过程中确定的所述多个节点的概率,确定奖励值,并将奖励值传输至所述至少一个Critic网络;
所述至少一个Critic网络用于确定目标Q值,并将所述目标Q值传输至所述Actor网络;
所述Actor网络用于基于所述目标Q值确定本次梯度更新过程中的每个节点的概率。
3.如权利要求2所述的方法,其特征在于,所述目标Q值为多个Critic网络确定的Q值中的最小Q值。
4.如权利要求2所述的方法,其特征在于,奖励函数满足:
Figure FDA0003494388320000011
其中,A为第一准确率,B为第二准确率,g大于或等于1,其中,第一准确率为所述中心服务器与所述多个节点基于联邦平均学习算法得到的训练完成的模型的准确率;第二准确率为所述多个节点分别发送的第三准确率的平均值,所述第三准确为与所述第一梯度在所述节点采用样本数据对所述节点中的待训练的模型进行同一次模型训练中得到的。
5.如权利要求4所述的方法,其特征在于,当
Figure FDA0003494388320000021
大于1时,g大于1;当
Figure FDA0003494388320000022
小于或等于1时,g为1。
6.如权利要求2所述的方法,其特征在于,所述Actor-Critic网络包括3个Critic网络,针对任一Critic网络,在本次梯度更新过程中确定的Q值基于Q值梯度和上一次梯度更新过程中确定的Q值确定,所述Q值梯度基于第一参数确定,所述第一参数满足以下公式:
Figure FDA0003494388320000023
其中,
Figure FDA0003494388320000024
其中,J为所述第一参数;t为本次梯度更新的次数;k>0,l>0,k+l=1;θ1,θ2,θ3分别表示3个Critic网络,θi为θ1,θ2,θ3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;st为第t次梯度更新过程中的状态;at为第t次梯度更新过程中所述多个节点的概率;
Figure FDA0003494388320000025
为第t次梯度更新过程中θi对应的Critic网络在st,at情况下确定的Q值;
Figure FDA0003494388320000026
为第t次梯度更新过程中θ3对应的Critic网络在st,at情况下输出的Q值;r(st,at)为第t次梯度更新过程中在st,at情况下的奖励值;γ大于0;πt(at|st)为在st下做出at的概率;q为熵的指数,lnq为熵,αt不为0。
7.如权利要求6所述的方法,其特征在于,在本次梯度更新过程中采用的α基于α梯度和上一次梯度更新过程中采用的α确定,所述α梯度满足以下公式:
Figure FDA0003494388320000027
其中,J(α)为α梯度,αt-1为上一次梯度更新采用的α,H为理想的最小期望熵。
8.如权利要求6所述的方法,其特征在于,所述k、l基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定。
9.如权利要求2所述的方法,其特征在于,所述Actor网络在本次梯度更新过程中输出的节点的概率基于概率梯度和上一次梯度更新过程中输出的节点的概率确定,所述概率梯度满足以下公式:
Figure FDA0003494388320000031
其中,J(πφ)为概率梯度,t为本次梯度更新的次数;θ1,θ2,θ3分别表示3个Critic网络,θi为θ1,θ2,θ3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;st为第t次梯度更新过程中的状态;at为第t次梯度更新过程中所述多个节点的概率;Qθi(st,at)为第t次梯度更新过程中θi对应的Critic网络在st,at情况下确定的Q值;πt(at|st)为在st下做出at的概率;q为熵的指数,lnq为熵,αt为本次梯度更新采用的α,αt不为0。
10.如权利要求1-9任一项所述的方法,其特征在于,所述中心服务器和所述多个节点基于联邦学习架构进行梯度更新。
11.一种模型梯度更新装置,其特征在于,包括:
重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
接收模块,用于接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;
处理模块,用于基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;
发送模块,用于将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
12.一种模型梯度更新装置,其特征在于,包括处理器和存储器;
所述存储器,用于存储计算机程序或指令;
所述处理器,用于执行所述存储器中的部分或者全部计算机程序或指令,当所述部分或者全部计算机程序或指令被执行时,用于实现如权利要求1-10任一项所述的方法。
13.一种计算机可读存储介质,其特征在于,用于存储计算机程序,所述计算机程序包括用于实现权利要求1-10任一项所述的方法的指令。
CN202210107380.3A 2022-01-28 2022-01-28 一种模型梯度更新方法及装置 Pending CN114492841A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202210107380.3A CN114492841A (zh) 2022-01-28 2022-01-28 一种模型梯度更新方法及装置
PCT/CN2022/112615 WO2023142439A1 (zh) 2022-01-28 2022-08-15 一种模型梯度更新方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210107380.3A CN114492841A (zh) 2022-01-28 2022-01-28 一种模型梯度更新方法及装置

Publications (1)

Publication Number Publication Date
CN114492841A true CN114492841A (zh) 2022-05-13

Family

ID=81477080

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210107380.3A Pending CN114492841A (zh) 2022-01-28 2022-01-28 一种模型梯度更新方法及装置

Country Status (2)

Country Link
CN (1) CN114492841A (zh)
WO (1) WO2023142439A1 (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2023142439A1 (zh) * 2022-01-28 2023-08-03 中国银联股份有限公司 一种模型梯度更新方法及装置

Family Cites Families (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11928598B2 (en) * 2019-10-24 2024-03-12 Alibaba Group Holding Limited Method and system for distributed neural network training
CN113282933B (zh) * 2020-07-17 2022-03-01 中兴通讯股份有限公司 联邦学习方法、装置和系统、电子设备、存储介质
CN112087518B (zh) * 2020-09-10 2022-10-21 中国工商银行股份有限公司 用于区块链的共识方法、装置、计算机系统和介质
CN112818394A (zh) * 2021-01-29 2021-05-18 西安交通大学 具有本地隐私保护的自适应异步联邦学习方法
CN113971089A (zh) * 2021-09-27 2022-01-25 国网冀北电力有限公司信息通信分公司 联邦学习系统设备节点选择的方法及装置
CN114492841A (zh) * 2022-01-28 2022-05-13 中国银联股份有限公司 一种模型梯度更新方法及装置

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2023142439A1 (zh) * 2022-01-28 2023-08-03 中国银联股份有限公司 一种模型梯度更新方法及装置

Also Published As

Publication number Publication date
WO2023142439A1 (zh) 2023-08-03

Similar Documents

Publication Publication Date Title
EP4078899A1 (en) Systems and methods for enhanced feedback for cascaded federated machine learning
CN111784002B (zh) 分布式数据处理方法、装置、计算机设备及存储介质
WO2021259090A1 (zh) 联邦学习的方法、装置和芯片
JP7383803B2 (ja) 不均一モデルタイプおよびアーキテクチャを使用した連合学習
US11797864B2 (en) Systems and methods for conditional generative models
US20220383200A1 (en) Method and apparatus for constructing multi-task learning model, electronic device, and storage medium
CN110770761A (zh) 深度学习系统和方法以及使用深度学习的无线网络优化
CN116745780A (zh) 用于去中心化联邦学习的方法和系统
CN112926747B (zh) 优化业务模型的方法及装置
CN113873534B (zh) 一种雾计算中区块链协助的联邦学习主动内容缓存方法
US11871251B2 (en) Method of association of user equipment in a cellular network according to a transferable association policy
US20220318412A1 (en) Privacy-aware pruning in machine learning
CN114492841A (zh) 一种模型梯度更新方法及装置
US10952120B1 (en) Online learning based smart steering system for wireless mesh networks
Kim et al. Learning to cooperate in decentralized wireless networks
CN114760308A (zh) 边缘计算卸载方法及装置
US20230153633A1 (en) Moderator for federated learning
WO2023061500A1 (en) Methods and systems for updating parameters of a parameterized optimization algorithm in federated learning
CN116010832A (zh) 联邦聚类方法、装置、中心服务器、系统和电子设备
Wu et al. Model-heterogeneous federated learning with partial model training
CN111967612A (zh) 横向联邦建模优化方法、装置、设备及可读存储介质
CN113806691B (zh) 一种分位数的获取方法、设备及存储介质
WO2023225552A1 (en) Decentralized federated learning using a random walk over a communication graph
CN117151206B (zh) 一种多智能体协同决策强化学习方法、系统及装置
US11076289B1 (en) AI-based multi-mode wireless access protocol (MMWAP)

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination