CN115600693A - 机器学习模型训练方法、识别方法、相关装置及电子设备 - Google Patents

机器学习模型训练方法、识别方法、相关装置及电子设备 Download PDF

Info

Publication number
CN115600693A
CN115600693A CN202211282362.5A CN202211282362A CN115600693A CN 115600693 A CN115600693 A CN 115600693A CN 202211282362 A CN202211282362 A CN 202211282362A CN 115600693 A CN115600693 A CN 115600693A
Authority
CN
China
Prior art keywords
local
model
global
training
parameter
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Withdrawn
Application number
CN202211282362.5A
Other languages
English (en)
Inventor
刘吉
贠瑜晖
窦德景
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202211282362.5A priority Critical patent/CN115600693A/zh
Publication of CN115600693A publication Critical patent/CN115600693A/zh
Withdrawn legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Image Analysis (AREA)

Abstract

本公开提供了一种机器学习模型训练方法、识别方法、相关装置、电子设备、存储介质及计算机程序产品,涉及人工智能技术领域,尤其涉及联邦学习技术领域。具体实现方案为:接收对全局模型的更新指令,所述指令用于触发本地模型的训练;响应所述更新指令,基于本地数据集,得到所述本地模型的本地动量参数;根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数;发送所述本地目标参数,所述本地目标参数用于供第二设备对所述全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。本公开为训练更精准的联邦学习模型提供了一种技术支持。

Description

机器学习模型训练方法、识别方法、相关装置及电子设备
技术领域
本公开涉及人工智能技术领域,尤其涉及联邦学习技术领域,更具体的涉及机器学习模型训练方法、识别方法、相关装置、电子设备、存储介质及计算机程序产品。
背景技术
作为一种机器学习模型,联邦学习模型因其具有对数据的隐私性保护而更受业内欢迎。但,目前联邦学习模型的使用效果不足,如识别准确度欠佳。
发明内容
本公开提供了一种机器学习模型训练方法、识别方法、相关装置、电子设备、存储介质及计算机程序产品。
根据本公开的一方面,提供了一种机器学习模型训练方法,应用于第一设备中,所述方法包括:
接收对全局模型的更新指令,所述指令用于触发本地模型的训练;
响应所述更新指令,基于本地数据集,得到所述本地模型的本地动量参数;
根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数;
发送所述本地目标参数,所述本地目标参数用于供第二设备对所述全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。
根据本公开的另一方面,提供了一种机器学习模型训练方法,应用于第二设备中,所述方法包括:
接收针对全局模型的更新指令而产生的多个本地目标参数,所述多个本地目标参数为各第一设备根据前述的应用于第一设备的机器学习模型训练方法而得到;
基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数;
采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,所述目标全局模型用于识别待识别数据中的目标对象。
根据本公开的再一方面,提供了一种识别方法,包括:
获得待识别数据;
将所述待识别数据输入至目标全局模型,得到所述待识别数据中的目标对象。
根据本公开的又一方面,提供了一种机器学习模型训练装置,包括:
接收单元,用于接收对全局模型的更新指令,所述指令用于触发本地模型的训练;
第一获得单元,用于响应所述更新指令,基于本地数据集,得到所述本地模型的本地动量参数;
第二获得单元,用于根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数;
发送单元,用于发送所述本地目标参数,所述本地目标参数用于供第二设备对所述全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。
根据本公开的又一方面,提供了一种机器学习模型训练装置,包括:
接收模块,用于接收针对全局模型的更新指令而产生的多个本地目标参数;
第一获得模块,用于基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数;
第二获得模块,用于采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,所述目标全局模型用于识别待识别数据中的目标对象。
根据本公开的又一方面,提供了一种识别装置,包括:
第一获得单元,用于获得待识别数据;
第二获得单元,用于将所述待识别数据输入目标全局模型,得到所述待识别数据中的目标对象。
根据本公开的又一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开任一实施例中的方法。
根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,该计算机指令用于使计算机执行本公开任一实施例中的方法。
根据本公开的又一方面,提供了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现本公开任一实施例中的方法。
本公开为训练更精准的联邦学习模型提供了一种技术支持。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是本公开实施例的联邦学习场景示意图;
图2是本公开实施例的应用于第一设备的机器学习模型训练方法的流程示意图一;
图3是本公开实施例的应用于第一设备的机器学习模型训练方法的流程示意图二;
图4是本公开实施例的应用于第二设备的机器学习模型训练方法的流程示意图一;
图5是本公开实施例的应用于第二设备的机器学习模型训练方法的流程示意图二;
图6是本公开实施例的机器学习模型训练装置一的组成示意图;
图7是本公开实施例的机器学习模型训练装置二的组成示意图;
图8是本公开实施例的识别装置的组成示意图;
图9是用来实现本公开实施例的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
本领域技术人员应该而知,通常,在联邦学习方式中,联邦学习架构包括有两种:一种是中心化联邦(客户端Client/服务器Server)架构,一种是去中心化联邦架构,即对等计算架构。
其中,结合图1所示,对于Client/Server架构而言,通常被应用在联合多方用户(边缘设备)的联合学习场景中。以待训练的模型为机器学习模型为例,在Client/Server架构中机器学习模型的训练原理是:服务器需要先把待训练的机器学习模型发送至选中的两个或两个以上的边缘设备,如图1所示的N个边缘设备中的至少两个设备,其中N为大于等于2的正整数。各被选中的边缘设备采用本地数据集(本地训练集)对待训练的机器学习模型进行本地训练,并将本地训练结果如本地训练出的模型参数发送至服务器。服务器对各被选中的边缘设备训练出的模型参数进行服务器侧的更新,得到服务器期望的机器学习模型。可见,服务器期望的机器学习模型是由被选中的边缘设备和服务器均进行训练而得到的。这种采用联邦学习架构,基于边缘设备和服务器之间的配合或联合而训练的模型通常被称之为联邦学习模型。理想或期望的联邦学习模型可通过边缘设备和服务器之间的配合或联合训练而得到。
考虑到,通常边缘设备采用的本地数据集是由自身收集而得到的,不对外共享,各边缘设备采用本地数据集对待训练的机器学习模型进行本地训练,实现了对本地数据集的隐私性保护,且为不泄露本地数据集,即可实现模型训练提供了一种技术支持。
考虑到前述的Client/Server架构的优势,本公开基于Client/Server架构而提出的一种对联邦学习模型进行(联合)训练的方法。这种训练方法,考虑到了动量这个参数,动量可在一定程度上体现模型的梯度下降情况,基于动量进行模型的训练,不仅可使模型梯度快速下降,实现模型的快速收敛、缩短训练时长,还可使模型训练得更加精准,为训练更精准的联邦学习模型提供了一种技术支持。且保证联邦学习模型的高使用效果,如识别准确性高。
下面对本公开的训练方法进行进一步说明。
本公开的机器学习模型的训练方法,应用于第一设备中。第一设备为图1所示的一个或多个边缘设备。如图2所示,所述方法包括:
S201:接收对全局模型的更新指令,所述指令用于触发本地模型的训练;
本步骤中,全局模型为需要边缘设备和服务器之间进行联合训练的联邦学习模型。全局模型可以是采用任何合理算法的机器学习模型,如卷积神经网络(CNN)模型、深度学习(VGG)模型。
更新指令为第二设备向被选中的边缘设备发送的需要更新或训练联邦学习模型的指令。在被选中的边缘设备接收到第二设备发送的该指令的情况下,执行本地更新流程,本地更新流程即对本地模型进行更新以对本地模型进行训练的流程。
其中,被选中的边缘设备是Client/Server架构中的全部边缘设备中的至少部分边缘设备。通常,被选中的边缘设备的数量大于或等于两个。第二设备为图1所示的服务器。
可以理解,在联合训练时,由于服务器侧和边缘设备侧均需要对各自的模型进行训练。可以理解,在联合训练中,联合训练的模型是同一模型,即,边缘设备和服务器联合训练的联邦学习模型为同一模型,如同一CNN模型或同一VGG模型。为区分,将服务器侧的该同一模型称之为全局模型,全局模型由服务器来进行训练。将边缘设备侧的该同一模型称之为本地模型,本地模型由边缘设备来训练。
同一机器学习模型被两种不同类型的设备(边缘设备和服务器)训练,可能训练出的模型参数不同。在联合训练中,边缘设备先对该同一模型进行训练,得到边缘设备训练出的模型参数,然后服务器对边缘设备训练出的模型参数进行进一步训练,以得到更加精准的模型参数,即,得到更加精准的联邦学习模型。
可以理解,通常模型参数包括两种类型参数:一种是与模型相关的参数,另一种是与模型调优、训练相关的参数。其中,与模型相关的参数包括模型的权重参数(w)、或权重参数(w)和偏置参数(b)。与模型调优、训练有关的参数目的是让模型训练的效果更好、收敛速度更快,这些调优参数通常被称为超参数(Hyper Parameters)。超参数选择的目标是:保证模型在训练阶段既不会拟合失败,也不会过度拟合,同时让模型尽可能快地学习数据结构特征。在实际应用中,常用的超参数通常包括学习率(learning rate)、迭代次数、激活函数等。
可以理解,对模型的训练意在对模型进行如上模型参数的更新或迭代,使模型更加准确。
S202:响应更新指令,基于本地数据集,得到本地模型的本地动量参数;
接收到更新指令的边缘设备响应更新指令,对本地模型进行训练。进一步的,每个边缘设备均具有本地数据集,本地数据集可作为本地训练的数据样本而使用。每个边缘设备的本地数据集可由自身设备对数据的收集而得到。不同边缘设备之间的本地数据集不共享,也无需共享给服务器。每个接收到更新指令的边缘设备基于自身收集的本地数据集,计算本地模型的本地动量参数。其中,动量参数的计算方法请参见后续相关说明。
示例性地,每个边缘设备收集到的本地数据集可以联邦学习通过横向学习而得到的数据,如每个本地数据集是在各消费场所如商场、超市、医院中出现的各消费者的人脸图像。
S203:根据本地动量参数对本地模型进行训练,得到本地模型的本地目标参数;
本步骤中,本地目标参数可视为对本地模型进行本地训练而得到的模型参数。本地动量参数是与本地模型的模型梯度相关的参数,利用与模型梯度相关的参数对本地模型进行训练,可得到精准的模型参数、且模型收敛得快、大大缩短边缘设备的本地训练时长。
S204:发送本地目标参数,所述本地目标参数用于供第二设备对全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。
本步骤中,第一设备将对本地模型进行训练而得到的模型参数发送至第二设备。第二设备根据第一设备发送的模型参数对全局模型进行参数更新,即对全局模型进行进一步训练,以得到期望的全局模型。
S201~S204中,第一设备响应接收到的针对全局模型的更新指令,执行本地模型的训练流程。在第一设备的本地训练过程中,基于本地模型的本地动量参数实现对本地模型的训练。其中,动量是可在一定程度上体现模型的梯度下降速度的参数,基于动量参数实现对本地模型的训练,不仅可使模型梯度快速下降,实现模型的快速收敛、缩短训练时长,还可使模型训练得更加精准,即本地目标参数更加精准。本地目标参数更加精准,可得到更加精准的全局模型,由此可训练出精准的联邦学习模型,提高联邦学习模型的使用效果,如提高识别准确性。可见,本公开为精准训练联邦学习模型提供了一种技术支持。
此外,第一设备实现了对本地数据集的隐私性保护,基于动量参数实现对本地模型的训练的方案,为本地模型的训练提供了一种新的技术支持。
在本公开的联合训练中,可将第二设备发送的更新指令视为触发第一设备进行本地模型的一次训练的事件。一次触发事件的产生,可令第一设备执行一次本地模型的(本地)训练流程。在一次本地训练流程中,第一设备可对本地模型的模型参数进行一次更新或迭代(单次训练),即针对接收到的某个更新指令,执行一次S202~S203的方案,或者进行多次更新或迭代(多次训练),即,针对接收到的某个更新指令,执行多次S202~S203的方案,以实现对本地模型的训练。
在本公开的联合训练中,第二设备存在有对全局模型的模型参数进行多轮迭代或多轮更新的需求。在存在每一轮迭代或更新需求的情况下,第二设备产生一次对全局模型的更新指令,并发送至该更新指令至该轮被选中的边缘设备。该轮被选中的边缘设备在接收到更新指令的情况下,按照各自的本地训练流程进行各自本地模型的训练。该轮被选中的边缘设备将训练出的各自的本地目标参数发送至第二设备。第二设备对该轮接收到的各被选中的边缘设备训练出的本地目标参数进行聚合,根据聚合结果进行全局模型的模型参数的该轮迭代或更新。对全局模型的每一轮迭代或更新参见如上流程,多次执行以上联合训练流程,直至第二设备训练出期望的全局模型。
第二设备在存在第t轮对全局模型进行迭代或更新需求情况下产生更新指令。该更新指令可视为第t轮更新指令。第一设备接收第二设备发送的该第t轮更新指令,并响应第t轮更新指令,对本地模型的模型参数进行一次本地训练流程,一次本地训练流程包括对本地模型的模型参数的多次更新或迭代,即对本地模型的模型参数进行多次训练。关于本地多次迭代的过程可以参见对图3所示的方案来理解。
在实际应用中,本申请实施例中由第二设备基于第一设备得到的本地目标参数来实现参数更新的全局模型可用于识别待识别图像中的目标对象,如识别人脸图像中的人脸、识别动物图片中的动物。即,利用完成参数更新的全局模型可实现对待识别图像中的目标对象的智能识别,为智能识别提供了一种技术支持。此外,全局模型是基于第一设备和第二设备的联合训练而训练出的,且在第一设备侧是基于动量实现的对全局模型需要的本地模型的本地目标参数的精准计算。本地目标参数的精准计算以及第一设备和第二设备的联合训练,可大大提高全局模型的训练精准性,由此可令被训练出的联邦学习模型实现精准的智能识别,提高联邦学习模型的使用(识别)效果。
基于此,第一设备的本地数据集可以是训练图像集,如人脸图像集、动物图片集。针对第二设备产生的第t轮更新指令,第一设备可基于本地的训练图像集,实现对本地模型的模型参数的本地训练。
在本地数据集包括多个训练图像集的情况下,如图3所示,前述的基于本地数据集得到本地模型的本地动量参数的方案通过S301和S302的方案来实现。
S301:基于每次训练下采用的训练图像集,得到每次训练下本地模型的图像梯度,将每次训练下的本地模型的图像梯度作为每次训练下的本地模型的本地梯度参数;
本步骤中,假定第一设备的一次本地训练包括对本地模型的T’次迭代或更新。T’为大于等于1的正整数,为一次本地训练流程中的总训练次数。
第一设备响应第二设备产生的第t轮更新指令,从本地数据集中,读取部分人脸图像作为第t’个训练图像集。其中,本地数据集可视为人脸图像的集合。其中,t’为大于等于1且小于等于T’的正整数。
以本地模型为CNN模型为例,将第t’个训练图像集输入至CNN模型,以令CNN模型进行人脸特征的学习。可以理解,CNN模型中包括卷积层。针对输入的第t’个训练图像集,卷积层可对第t’个训练图像集中的人脸图像进行特征提取。提取出的特征包括低层特征和高层特征。其中,低层特征包括图像中的各区域的颜色、灰度、边缘、纹理和形状;高层特征包括图像所表达的语义,如图像是表示人脸的图像。基于低层特征和高层特征,第一设备采用导数求取方法,计算第t’个训练图像集的图像梯度。其中,采用导数求取方法计算图像梯度的方法、以及图像梯度的说明请参见相关描述,不赘述。
第一设备将第t’个训练图像集的图像梯度作为第一设备在响应第t轮更新指令的本地训练流程中、对本地模型进行第t’次训练而得到的本地梯度参数。
S302:基于每次训练下本地模型的本地梯度参数,对每次训练对应的参考动量参数进行调整,得到每次训练下本地模型的本地动量参数。
本步骤中,根据公式(1)计算第t次训练下的本地动量参数。
Figure BDA0003898658260000091
Figure BDA0003898658260000092
表示被选中的第k个第一设备响应第二设备产生的第t轮更新指令,在一次本地训练流程对本地模型进行第t’次训练而得到的本地动量参数。g'k()表示被选中的第k个第一设备在一次本地训练流程中对本地模型进行第t’次训练而得到的本地梯度参数。在本公开中,g'k()是带有本地模型的权重参数wk的梯度参数,所以本地梯度参数写成了
Figure BDA0003898658260000093
Figure BDA0003898658260000094
为被选中的第k个第一设备响应第二设备产生的第t轮更新指令在一次本地训练流程中对本地模型进行第t’-1次训练而得到的权重参数wk,可作为第t次训练下的参考动量参数使用。
Figure BDA0003898658260000095
表示被选中的第k个第一设备响应第二设备产生的第t轮更新指令,在一次本地训练流程的第t’-1次训练下而得到本地动量参数。β′表示为梯度配置的权重。
在公式(1)中,在
Figure BDA0003898658260000096
β′、
Figure BDA0003898658260000097
均已知的情况下,可得到
Figure BDA0003898658260000098
从公式(1)中可看出,在本地训练流程中,每次训练下得到的本地动量参数均是基于本地梯度参数而得到的。
基于本地梯度参数实现对动量参数的调整或计算,可保证动量参数的计算准确性。动量参数的计算准确性,可保证对本地模型的模型参数(如本地目标参数)的计算准确性。由此可提高第二设备对全局模型的训练准确性。
相应的,前述的根据本地动量参数对所述本地模型进行训练,得到本地模型的本地目标参数的方案通过S303和S304的方案来实现。
S303:基于每次训练下本地模型的本地动量参数,对每次训练下本地模型的初始模型参数进行调整,得到每次训练下本地模型的候选参数;
本步骤中,根据公式(2)进行第t次训练下的本地模型的候选参数
Figure BDA0003898658260000101
的计算。
Figure BDA0003898658260000102
其中,
Figure BDA0003898658260000103
表示被选中的第k个第一设备响应第二设备产生的第t轮更新指令,在一次本地训练流程的第t’-1次训练下而得到的权重参数,作为第t次训练下本地模型的初始模型参数而使用。其中,k为大于等于1且小于等于N的正整数。η′是本地模型的学习率。
Figure BDA0003898658260000104
表示被选中的第k个第一设备响应第二设备产生的第t轮更新指令,在一次本地训练流程的对本地模型进行第t’次训练而得到的权重参数,作为候选参数而使用。
在公式(2)中,在
Figure BDA0003898658260000105
均已知的情况下,可得到
Figure BDA0003898658260000106
需要说明的是,在一次本地训练流程中,在按照公式(2)计算出
Figure BDA0003898658260000107
的情况下,如果t’大于等于T’,则t’=t’+1,重复执行S301~S303所示的方案,直至t’=T’。
在重复执行S301~S303过程中,可以看出,公式(1)会被多次更新或迭代,在联邦学习架构中,这种动量的迭代可被称之为动量的自适应更新。在本公开的本地模型训练方案中,采用动量的自适应更新方案来实现对本地模型的模型参数的更新或迭代。这种自适应更新方案,可使本地模型的下降梯度快,缩短训练流程,保证本地模型的本地目标参数的计算精准性。
在T’次训练过程中,如果t’≠1,则第t’次训练次数下本地模型的初始模型参数为第t’-1次训练下而得到的权重参数。即,在多次训练的非首次训练下,本地模型的初始模型参数为非首次训练的前一次训练下而得到的本地模型的候选参数。
第二设备向第一设备发送的第t轮更新指令中还携带有全局模型在第t轮需要更新的模型权重。如果t’=1,则第t’次训练次数下本地模型的初始模型参数是第二设备在更新指令中携带的全局模型的模型参数。即,在多次训练的首次训练下,本地模型的初始模型参数基于更新指令而得到。
也就是说,在第二设备产生对全局模型进行第t轮更新的更新指令的情况下,被选中的第一设备可使用更新指令中携带的全局模型的待更新权重参数进行本地模型的模型参数的多次迭代或计算,并将训练出的本地模型的模型参数如模型权重返回至第一设备。
这种基于t’的取值是否为1的结果确定相应次训练下采用的初始模型参数。基于相应次训练下的初始模型参数进行本地模型的相应次迭代或更新,以实现对本地模型的训练,可实现对本地模型的模型参数(如候选参数、本地目标参数)的准确计算。
S304:基于每次训练下本地模型的候选参数,得到本地模型的本地目标参数。
可以理解,本公开中,本地模型的训练意在采用公式(1)和公式(2),对本地模型的权重参数w′k进行多次更新或迭代。本地训练流程中,最后一次迭代得出的权重参数
Figure BDA0003898658260000111
即可作为第一设备对第二设备产生的第t轮更新指令进行响应而得到的模型参数。该模型参数可作为第一设备对第二设备产生的第t轮更新指令进行响应而得到的本地模型的本地目标参数。
也就是说,在本公开中,本地目标参数可以为最后一次训练次数下而得到的候选参数。即,本地目标参数为
Figure BDA0003898658260000112
当然,本地目标参数也可以为每次训练而得到的本次模型的权重参数的算术平均值或加权平均值,本地目标参数还可以为多次训练而得到的权重参数中的其中之一。优选的,本地目标参数为多次训练中的最后一次训练下而得到的候选参数。本地目标参数通过本地的多次迭代或更新而得到,可保证本地目标参数的精准性,由此可提高第二设备对全局模型的训练精准性。
从本地模型的角度来看,动量在一定程度上体现了本地模型的梯度下降情况如梯度下降方向和/或下降速度,基于本地动量参数对本地模型进行训练,可使本地模型快速收敛,加快完成本地训练流程,缩短本地训练时长。基于动量参数训练出的本地模型的模型参数(本地目标参数)更加精准,由此也可大大提高第二设备对全局模型的训练准确性。
作为对第二设备的第t轮更新指令的响应结果,被选中的第k个第一设备将经过多次迭代或更新而得到的本地目标参数发送至第二设备,以供第二设备进行全局模型的更新。
其中,被选中的各第一设备针对某一轮的更新指令向第二设备反馈的信息为第一设备在本地训练流程中最后一次训练次数下而得到的候选参数,无需反馈其他信息。反馈的信息较少,可有效减少第一设备和第二设备之间的通信开销。
通常,本地模型的初始模型参数、候选参数、本地目标参数、本地模型的模型参数主要指的是本地模型的权重参数。
在一些可选实施例中,完成参数更新的全局模型通过采用全局动量参数对全局模型的全局参考参数进行调整而得到;其中全局动量参数基于至少一个第一设备发送的本地目标参数而得到的。
在第二设备侧,基于全局动量参数对全局模型进行训练。全局动量参数在一定程度上体现(全局)模型的梯度下降速度,基于动量对模型进行训练,不仅可使模型梯度快速下降,实现模型的快速收敛、缩短训练时长,还可使模型训练得更加精准。
其中,全局动量参数可通过任何合理方式的而得到,如通过动量计算方法而得到。在一些可选实施例中,全局动量参数可基于全局参考动量和全局梯度而得到。其中全局梯度基于至少一个第一设备的本地目标参数之间的聚合结果而得到。这种基于全局参考动量和全局梯度得到全局动量参数的方案,可保证全局动量参数的计算准确性,实现对全局模型的精准训练,由此可得到更加精准的联邦学习模型。
以上关于第二设备对全局模型的训练过程请参见后续相关说明。
本公开中,存在有对全局模型的每一轮的更新需求的情况下,第二设备从所有可选第一设备中进行各轮中需要对本地模型进行训练的第一设备的选择。每个第一设备按照图2或图3所示的方案执行相应轮下的本地模型的训练流程。在训练流程结束的情况下,同一轮下各被选中的第一设备将自身训练出的本地目标参数发送至第二设备,以供第二设备采用针对该轮下的更新指令而反馈回的各第一设备的本地目标参数对全局模型进行该轮的训练。下面对第二设备对全局模型进行训练的过程进行说明。
本公开实施例中,提供一种应用于第二设备的机器学习模型训练方法。第二设备可以为图1所示的服务器。如图4所示,所述方法包括:
S401:接收针对全局模型的更新指令而产生的多个本地目标参数,所述多个本地目标参数为各第一设备根据前述应用在第一设备中的训练方法而得到;
本步骤中,第二设备在存在有对全局模型进行更新需求的情况下,产生更新指令,并发送到各被选中的边缘设备。各被选中的边缘设备按照前述应用于第一设备中的训练方法进行本地模型的训练,得到本地模型的本地目标参数。第二设备接收各被选中的边缘设备针对更新指令而反馈回的各本地目标参数。
S402:基于对多个本地目标参数的聚合结果,得到全局模型的全局动量参数;
本步骤中,采用预设聚合算法如联邦聚合算法对多个本地目标参数进行聚合,得到对多个本地目标参数的聚合结果。其中,联邦聚合算法包括但不限定于:FedAvg、FedProx、SCAFFOLD。基于联邦聚合算法进行聚合可保证聚合结果的准确性,进而保证全局动量参数的计算准确性。
本步骤中,可采用动量计算方法基于聚合结果进行全局动量参数的计算。
S403:采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,所述目标全局模型用于识别待识别数据中的目标对象。
目标全局模型可认为是第二设备训练出的期望联邦学习模型。S402和S403可视为第二设备根据接收到的各本地目标参数对全局模型进行训练的方案。
S401~S403中,一方面,第一设备侧的本地模型的本地目标参数的精准性,可保证对全局模型的训练精准性。另一方面,第二设备采用(全局)动量参数对(全局)模型进行训练,动量可使模型梯度下降快,实现对模型的快速收敛,不仅使得模型更加精准,还可有效缩短时长。联邦学习模型的精准训练,可有效提高学习模型的使用效果,如提高识别准确性。
在第二设备对全局模型进行训练的方案中,可对全局模型进行单轮或多轮训练,以得到期望的联邦学习模型。在每一轮中,第二设备对全局模型的迭代或更新可以是一次,可以是多次。
下面以第二设备对全局模型进行第t轮更新、第t轮更新中迭代一次为例进行第二设备根据接收到的各本地目标参数对全局模型进行训练的方案的说明。
在一些可选实施例中,前述的基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数的方案可通过图5所示的S501和S502来实现。
S501:基于所述多个本地目标参数的聚合结果,得到全局模型的全局梯度;
本步骤中,可采用公式(3)对接收到的各本地目标参数进行聚合。
Figure BDA0003898658260000141
其中,
Figure BDA0003898658260000142
表示在第t轮下第二设备对接收到的各地本地目标参数进行聚合得到的聚合结果。
Figure BDA0003898658260000143
表示第k个第一设备针对第二设备在第t轮训练下产生的第t个更新指令而向第二设备反馈回的本地目标参数。
Figure BDA0003898658260000144
表示第二设备的第t轮训练下被选中的第一设备的集合。nk表示第t轮训练下被选中的第k个第一设备的训练样本数量。n'表示第t轮训练下被选中的所有第一设备的训练样本数量的总和。
接下来,采用(4)对聚合结果进行处理,得到全局模型的全局梯度。
Figure BDA0003898658260000145
其中,g(wt-1)表示在第t轮下对聚合结果进行处理得到的全局模型的全局梯度。η为全局模型的超参数,表示全局模型的学习率。
wt(0)表示在第二设备需要有对全局模型的第t轮更新需求情况时全局模型的模型参数,该模型参数是需要经过第t轮下第一设备的本地训练流程和第二设备的更新流程进行更新的。该模型参数需要通过携带在更新指令中由第二设备发送至各被选中的第一设备。第一设备将本地模型的模型参数取值为更新指令中的模型参数,并进行本地训练流程。然后,第二设备根据接收到在第一设备的本地训练中得到的模型参数进行服务器侧的对全局模型的第t轮更新方案。相当于,在第t轮中,需要第一设备利用本地训练流程对wt(0)进行更新,得到本地模型的目标模型参数。第二设备利用各第一设备训练出的目标模型参数进行继续更新,得到第t轮中对全局模型的模型参数的更新结果。t=t+1,按照第t轮的过程继续进行第t+1轮的训练方案,直至服务器得到期望的全局模型的模型参数。
其中,
Figure BDA0003898658260000151
其中,τ是第t轮下的总迭代次数。为方便起见,可以取τ为1。当然,τ可以根据实际情况取任何合理的取值。i表示第t轮下的第几次迭代。在τ=1的情况下,i取值为1。
Figure BDA0003898658260000152
是第t轮下服务器对聚合结果
Figure BDA0003898658260000153
进行了第i次迭代而得到的参数。
Figure BDA0003898658260000154
()是服务器侧的随机梯度。
在实际应用中,
Figure BDA0003898658260000155
其中E表示在某一个轮中在第一设备的本地模型训练中需要迭代的次数。B表示批大小,通常为固定值。n0表示服务器侧使用的训练样本数量。
Figure BDA0003898658260000156
其中,acct-1是将服务器的训练样本输入至权重参数取值为
Figure BDA0003898658260000157
的全局模型下得到的识别准确率。以训练样本包括多个人脸图像为例,以人脸图像作为权重参数取值为
Figure BDA0003898658260000158
的全局模型的输入,计算能识别人脸区域的人脸图像的数量,计算该数量占输入图像总数量的百分比,将此百分比作为识别准确率。其中,
Figure BDA0003898658260000159
是在第t-1轮下各被选中的第一设备反馈的本地目标参数进行聚合后得到的模型参数。
Figure BDA00038986582600001510
表示第t-1轮被选中的第一设备的训练样本的总体数据分布。P0表示服务器侧的训练数据的分布。decay∈(0,1),为固定值,用于保证模型的收敛。
Figure BDA0003898658260000161
是超参数。f′(acc)是关于acc的函数,通常采用的是f′(acc)=1-acc。在本公开中,在对全局模型进行初期训练(如第t=1轮训练)时,取acc的值较小,相应的f′(acc)的值就比较大,从而更多地利用服务器的训练样本更新模型。随着更新轮数的增加,在训练后期,f′(acc)的值会变小,如此,便减少了服务器的训练样本对全局模型的训练的影响。
这里,需要说明的是,在服务器对全局模型的第t=1轮更新方案中,在第t=1个更新指令中携带并向各被选中的边缘设备下发的全局模型的模型参数是通过采用服务器侧的训练样本对全局模型进行训练而得到的模型参数。服务器侧的训练样本可以是无需进行隐私保护的样本,如无需进行隐私保护的人脸图像。在后续轮的更新中,每一轮中向边缘设备下发的全局模型的模型参数可以是服务器在上一轮中对全局模型进行训练而得到的模型参数。即,随着对全局模型的更新轮数的增加,服务器侧的训练样本对模型参数的影响越来越小。
考虑到通常服务器侧的训练样本的数量较少,与依赖训练样本相比,本公开中的对全局模型的每轮更新更加依赖于前一轮的模型参数的更新结果,不仅减少了对训练样本的依赖,避免因训练样本不足而导致的模型训练不准确的问题,还实现了对模型参数的准确迭代。可见,本公开为服务器侧训练样本少但能训练出准确的联邦学习模型提供了一种技术支持。
S502:基于全局梯度g(wt-1)和全局参考动量,得到全局模型的全局动量参数。
本步骤中,根据公式(7)进行全局动量参数的计算。
mt=β*mt-1+(1-β)*g(wt-1) (7)
其中,mt为第t轮的全局动量参数。mt-1为第t-1轮的全局动量参数,作为全局参考动量使用。β是为全局模型的梯度配置的权重。
此处,公式(7)可认为是采用第t轮的全局梯度,对第t-1轮的全局动量参数进行调整以得到第t轮的全局动量参数。
可以理解,如果服务器对全局模型进行多轮更新,每轮更新中服务器对全局模型进行L次迭代就采用公式(7)L次对相应轮下的全局动量进行L次更新。其中,L为大于等于1的正整数。
在联邦学习架构中,这种对全局动量的更新可称之为动量的自适应更新方案。在本公开的服务端训练(服务器对全局模型进行训练)方案中,采用动量的自适应更新方案来实现对全局模型的模型参数的更新或迭代。这种方案,可使全局模型的下降梯度快,缩短服务端对全局模型的训练流程,提高全局模型训练精准性,提高联邦学习模型的使用(如识别)效果。
相应的,前述的采用全局动量参数,对所述全局模型进行参数更新的方案可通过图5所示的S503和S504来实现。
S503:采用全局动量参数,对全局模型的全局参考参数进行调整,得到全局目标参数;
本步骤中,采用公式(8)进行全局目标参数的计算。
wt=wt-1-η*mt (8)
其中,mt为第t轮的全局动量参数;wt-1为第t-1轮对全局模型的迭代或更新而产生的模型参数,作为全局参考参数使用。η为全局模型的学习率。
从公式(8)可以看出,服务器可对全局模型进行多轮迭代或训练。在一轮迭代中,服务器均可采用公式(3)对该轮下各被选中的边缘设备上报的本地目标参数进行聚合,并根据公式(7)和公式(8)实现模型参数的一轮迭代。下一轮迭代时向边缘设备发送的更新指令中携带的模型参数是该下一轮迭代之前的一轮采用公式(7)和公式(8)得出的模型参数-全局目标参数,可得到更加精准的模型参数。如此多轮即可实现服务器对全局模型的精准训练。
S504:基于全局目标参数,得到目标全局模型。
本步骤中,可将模型参数取值为任意轮训练出的全局目标参数的全局模型作为目标全局模型。优选的,可将模型参数取值为最后一轮训练出的全局目标模型的全局模型作为目标全局模型。
目标全局模型可以为服务器期望训练出的全局模型。期望的全局模型可经过对全局模型进行多轮训练而得到。经过多轮且基于全局动量对全局模型进行训练的方案,可保证联邦学习模型的训练精准性。精准的联邦学习模型可大大发挥其使用效果,示例性地,对人脸图像中的人脸区域识别得更加精准。
通常情况下,前述的全局参考参数、全局目标参数、全局模型的模型参数主要是指全局模型的权重参数。
本公开中,理想或期望的联邦学习模型是通过边缘设备和服务器之间的配合或联合训练而得到。本公开中联合训练的主要原理是:
针对服务器产生的对全局模型的第t轮更新需求,产生第t个更新指令,发送第t个更新指令至在第t轮中被选中的各边缘设备。各边缘设备对接收到的第t个更新指令进行响应,采用如图2或图3所示的本地训练流程对本地模型进行多次更新,并将多次更新后的本地模型的模型参数作为对第t个更新指令的反馈信息,反馈至服务器。服务器接收到反馈信息,按照图5所示的方案对全局模型进行第t轮更新。如果t不等于T,则t=t+1,继续执行以上方案,直至t等于T训练出期望的联邦学习模型。其中,T为服务器对全局模型进行训练的最大轮数。
其中,第t轮更新指令中携带有该轮下第二设备向边缘设备提供的全局模型的待训练模型参数,各边缘设备接收到更新指令的情况下,在本地训练流程中,将更新指令中携带的待训练模型参数作为本地模型的迭代前的模型参数,按照该模型参数执行本地的多次迭代或更新,得到经本地多次更新或迭代后的模型参数,将本地多次更新或迭代后的模型参数作为对第t轮更新指令的反馈信息或响应结果,反馈至服务器。服务器先对接收到的反馈信息进行聚合,基于聚合结果对全局模型进行第t轮的更新。在服务器每产生一轮想要对全局模型进行更新的更新指令的情况下,就按照前述方案执行一轮对全局模型的更新或训练方案。
通常情况下,需要对全局模型进行T轮更新才能训练出期望的联邦学习模型。经过不仅一轮得到期望的联邦学习模型的方案,为一种对联邦学习模型进行不断动态优化的方案。这种方案可保证联邦学习模型的训练精准性。其中,t为大于等于1且小于等于T的正整数。
从前述的对本地模型的本地动量和全局模型的全局动量的计算公式中可看出,本地动量和全局动量是动态优化的参数,基于动态优化的参数对本地模型或全局模型进行训练,可保证联邦学习模型的训练精准性。
可以理解,在联合训练中,边缘设备执行的对本地模型的训练,服务器执行的对全局模型的训练,均是基于动量进行的训练,不仅可使模型梯度快速下降,实现模型的快速收敛、缩短训练时长,还可使模型训练得更加精准。由此可见,本公开提供的基于动量进行联邦学习模型的联合训练方案可提高联邦学习模型的训练精准性,由此可提高联邦学习模型的使用效果。
在基于联合训练原理训练出期望的全局模型之后,可使用训练出的期望全局模型对待识别数据中的目标对象进行识别。基于此,本公开提供一种识别方法,包括:获得待识别数据;将所述待识别数据输入至前述的目标全局模型,得到所述待识别数据中的目标对象。
在实施时,获得待识别数据,将待识别数据输入至目标全局模型,得到目标全局模型输出的待识别数据中的目标对象。
示例性地,采集人脸图像,将采集到的人脸图像作为待识别数据,将采集到的人脸图像输入至目标全局模型,得到目标全局模型输出的人脸图像中的目标对象-人脸区域。可以理解,由于目标全局模型是基于动量进行联邦学习或联合训练而得到的,所以目标全局模型具有精准性,可提高对待识别数据中的目标对象的识别精准性。
前述方案中,是以第一设备的本地数据集、服务器的训练样本和待识别数据为图像如人脸图像为例,除此之外,本公开中的本地数据集、服务器的训练样本和待识别数据可以为语音、文本等。
在待识别数据为语音数据时,可将待识别的语音数据输入至目标全局模型,得到目标全局模型输出的待识别语音数据中的敏感数据或骚扰数据。敏感数据或骚扰数据可作为待识别语音数据中的目标对象使用。其中,敏感数据可以是诸如身份证号、医保卡号等具有隐私性的数据。骚扰数据可以是诸如辱骂、暴力语言等数据。
在待识别数据为文本数据时,可将待识别的文本数据输入至目标全局模型,得到目标全局模型输出的待识别文本数据中的特定数据。其中,特定数据可以是预设的任何合理的数据,如前述的以文字形式记载的隐私性数据或骚扰数据。诸如身份证号、医保卡号等具有隐私性的数据。
由于目标全局模型的精准性,所以采用目标全局模型不论是识别待识别语音数据中的敏感数据或骚扰数据,还是识别待识别文本数据中的特定数据,均可提高识别准确性,达到对待识别数据中的目标对象的精准识别,令目标全局模型具有很好的使用效果。
本公开提供一种机器学习模型训练装置,如图6所示,所述装置包括:
接收单元601,用于接收对全局模型的更新指令,所述指令用于触发本地模型的训练;
第一获得单元602,用于响应所述更新指令,基于本地数据集,得到所述本地模型的本地动量参数;
第二获得单元603,用于根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数;
发送单元604,用于发送所述本地目标参数,所述本地目标参数用于供第二设备对所述全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。
在一些实施例中,所述本地数据集包括多个训练图像集,所述第一获得单元602,用于:
基于每次训练下采用的训练图像集,得到每次训练下所述本地模型的图像梯度,将所述图像梯度作为本地梯度参数;
基于每次训练下所述本地模型的本地梯度参数,对每次训练对应的参考动量参数进行调整,得到每次训练下所述本地模型的本地动量参数。
在一些实施例中,所述第二获得单元603,用于:
基于每次训练下采用的训练图像集,得到每次训练下所述本地模型的图像梯度,将所述图像梯度作为本地梯度参数;
基于每次训练下所述本地模型的本地梯度参数,对每次训练对应的参考动量参数进行调整,得到每次训练下所述本地模型的本地动量参数。
在一些实施例中,所述本地目标参数为多次训练中的最后一次训练下而得到的候选参数。
在一些实施例中,在多次训练的非首次训练下,所述本地模型的初始模型参数为所述非首次训练的前一次训练下而得到的所述本地模型的候选参数;在多次训练的首次训练下,所述本地模型的初始模型参数基于所述更新指令而得到。
在一些实施例中,所述完成参数更新的全局模型用于识别待识别图像中的目标对象。
在一些实施例中,所述完成参数更新的全局模型通过采用全局动量参数对所述全局模型的全局参考参数进行调整而得到;
其中,所述全局动量参数基于至少一个所述第一设备发送的本地目标参数而得到的。
在一些实施例中,所述全局动量参数基于全局参考动量和全局梯度而得到,其中所述全局梯度基于所述至少一个所述第一设备的本地目标参数之间的聚合结果而得到。
本公开提供另一种机器学习模型训练装置,如图7所示,包括:
接收模块701,用于接收针对全局模型的更新指令而产生的多个本地目标参数;
第一获得模块702,用于基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数;
第二获得模块703,用于采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,所述目标全局模型用于识别待识别数据中的目标对象。
在一些实施例中,所述聚合结果通过采用预设聚合算法对所述多个本地目标参数进行聚合而得到。
在一些实施例中,所述第一获得模块702,用于:
基于所述多个本地目标参数的聚合结果,得到所述全局模型的全局梯度;
基于所述全局梯度和全局参考动量,得到全局模型的全局动量参数。
在一些实施例中,所述第二获得模块703,用于:
采用全局动量参数,对全局模型的全局参考参数进行调整,得到全局目标参数;
基于所述全局目标参数,得到目标全局模型。
需要说明的是,图6所示的装置可位于图1所示的边缘设备中或为边缘设备本身。图7所示的装置可位于图2所示的服务器中或为服务器本身。
本公开提供一种识别装置,如图8所示,所述装置包括:
第一获得单元801,用于获得待识别数据;
第二获得单元802,用于将所述待识别数据输入前述的目标全局模型,得到所述待识别数据中的目标对象。
本公开实施例的两种机器学习模型训练装置和识别装置中的各组成单元的功能可以参见相关方法的描述,在此不再赘述。本公开实施例的前述装置,由于解决问题的原理与前述的相关方法相似,因此,装置的实施过程及实施原理、有益效果均可以参见前述相关方法的实施过程及实施原理、有益效果描述,重复之处不再赘述。
根据本公开的实施例,本公开还提供了一种电子设备,所述电子设备包括至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行前述的机器学习模型训练方法和/或识别方法。
关于电子设备的处理器、存储器的描述可参见图9中计算单元901、存储单元908的相关说明。
根据本公开的实施例,本公开还提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使计算机执行前述的机器学习模型训练方法和/或识别方法。关于计算机可读存储介质的说明请参见图9中的相关说明。
根据本公开的实施例,本公开还提供了一种计算机程序产品,包括计算机程序,该计算机程序在被处理器执行时实现前述的机器学习模型训练方法和/或识别方法。关于计算机程序产品的说明请参见图9中的相关说明。
本公开的技术方案中,所涉及的用户个人信息(如人脸图像)的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
图9是用来实现本公开实施例的电子设备的框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或要求的本公开的实现。
如图9所示,电子设备900包括计算单元901,其可以根据存储在ROM 902中的计算机程序或者从存储单元908加载到RAM 903中的计算机程序来执行各种适当的动作和处理。在RAM 903中,还可存储电子设备900操作所需的各种程序和数据。计算单元901、ROM 902以及RAM 903通过总线904彼此相连。输入输出(I/O)接口905也连接至总线904。
电子设备900中的多个部件连接至I/O接口905,包括:输入单元906,例如键盘、鼠标等;输出单元907,例如各种类型的显示器、扬声器等;存储单元908,例如磁盘、光盘等任何可作为存储器使用的器件;以及通信单元909,例如网卡、调制解调器、无线通信收发机等。通信单元909允许电子设备900通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
本公开实施例中的存储单元908可以具体为便携式计算机盘、硬盘、随机存储器(RAM)、只读存储器(ROM)、可擦编程只读存储器(EPROM)或快闪存储器、光纤、CD-ROM、光学储存设备、磁储存设备中的至少一种存储器。
计算单元901可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元901的一些示例包括但不限于CPU、图形处理单元(GPU)、人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字处理器(DSP)、以及任何适当的处理器、控制器、微控制器等处理器。计算单元901执行上文所描述的各个方法和处理,例如机器学习模型训练方法和/或识别方法。例如,在一些实施例中,机器学习模型训练方法和/或识别方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元908。在一些实施例中,计算机程序的部分或者全部可以经由ROM 902和/或通信单元909而被载入和/或安装到电子设备900上。当计算机程序加载到RAM 903并由计算单元901执行时,可以执行上文描述的机器学习模型训练方法和/或识别方法的一个或多个步骤。备选地,在其他实施例中,计算单元901可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行机器学习模型训练方法和/或识别方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、可编辑阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器等可执行计算机程序代码的产品,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质(存储介质)可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、RAM、ROM、EPROM或快闪存储器、光纤、CD-ROM、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入、或者触觉输入来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。

Claims (29)

1.一种机器学习模型训练方法,应用于第一设备中,所述方法包括:
接收对全局模型的更新指令,所述指令用于触发本地模型的训练;
响应所述更新指令,基于本地数据集,得到所述本地模型的本地动量参数;
根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数;
发送所述本地目标参数,所述本地目标参数用于供第二设备对所述全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。
2.根据权利要求1所述的方法,所述本地数据集包括多个训练图像集,其中,所述基于本地数据集,得到所述本地模型的本地动量参数,包括:
基于每次训练下采用的训练图像集,得到每次训练下所述本地模型的图像梯度,将所述图像梯度作为本地梯度参数;
基于每次训练下所述本地模型的本地梯度参数,对每次训练对应的参考动量参数进行调整,得到每次训练下所述本地模型的本地动量参数。
3.根据权利要求2所述的方法,其中,所述根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数,包括:
基于每次训练下所述本地模型的本地动量参数,对每次训练下所述本地模型的初始模型参数进行调整,得到每次训练下所述本地模型的候选参数;
基于每次训练下所述本地模型的候选参数,得到所述本地模型的本地目标参数。
4.根据权利要求3所述的方法,其中,所述本地目标参数为多次训练中的最后一次训练下而得到的候选参数。
5.根据权利要求2或3或4所述的方法,其中,
在多次训练的非首次训练下,所述本地模型的初始模型参数为所述非首次训练的前一次训练下而得到的所述本地模型的候选参数;
在多次训练的首次训练下,所述本地模型的初始模型参数基于所述更新指令而得到。
6.根据权利要求1至5中任一项所述的方法,其中,所述完成参数更新的全局模型用于识别待识别图像中的目标对象。
7.根据权利要求1至5任一项所述的方法,其中,所述完成参数更新的全局模型通过采用全局动量参数对所述全局模型的全局参考参数进行调整而得到;
其中,所述全局动量参数基于至少一个所述第一设备发送的本地目标参数而得到的。
8.根据权利要求7所述的方法,其中,所述全局动量参数基于全局参考动量和全局梯度而得到,其中所述全局梯度基于所述至少一个所述第一设备的本地目标参数之间的聚合结果而得到。
9.一种机器学习模型训练方法,应用于第二设备中,所述方法包括:
接收针对全局模型的更新指令而产生的多个本地目标参数,所述多个本地目标参数为各第一设备根据权利要求1-5中任一项所述的方法而得到;
基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数;
采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,所述目标全局模型用于识别待识别数据中的目标对象。
10.根据权利要求9所述的方法,其中,所述聚合结果通过采用预设聚合算法对所述多个本地目标参数进行聚合而得到。
11.根据权利要求9或10所述的方法,其中,所述基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数,包括:
基于所述多个本地目标参数的聚合结果,得到所述全局模型的全局梯度;
基于所述全局梯度和全局参考动量,得到全局模型的全局动量参数。
12.根据权利要求11所述的方法,其中,所述采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,包括:
采用全局动量参数,对全局模型的全局参考参数进行调整,得到全局目标参数;
基于所述全局目标参数,得到目标全局模型。
13.一种识别方法,包括:
获得待识别数据;
将所述待识别数据输入至权利要求9至12中任一项所述的目标全局模型,得到所述待识别数据中的目标对象。
14.一种机器学习模型训练装置,包括:
接收单元,用于接收对全局模型的更新指令,所述指令用于触发本地模型的训练;
第一获得单元,用于响应所述更新指令,基于本地数据集,得到所述本地模型的本地动量参数;
第二获得单元,用于根据所述本地动量参数对所述本地模型进行训练,得到所述本地模型的本地目标参数;
发送单元,用于发送所述本地目标参数,所述本地目标参数用于供第二设备对所述全局模型进行参数更新,完成参数更新的全局模型用于识别待识别数据中的目标对象。
15.根据权利要求14所述的装置,其中,所述本地数据集包括多个训练图像集,所述第一获得单元,用于:
基于每次训练下采用的训练图像集,得到每次训练下所述本地模型的图像梯度,将所述图像梯度作为本地梯度参数;
基于每次训练下所述本地模型的本地梯度参数,对每次训练对应的参考动量参数进行调整,得到每次训练下所述本地模型的本地动量参数。
16.根据权利要求15所述的装置,其中,所述第二获得单元,用于:
基于每次训练下采用的训练图像集,得到每次训练下所述本地模型的图像梯度,将所述图像梯度作为本地梯度参数;
基于每次训练下所述本地模型的本地梯度参数,对每次训练对应的参考动量参数进行调整,得到每次训练下所述本地模型的本地动量参数。
17.根据权利要求16所述的装置,其中,所述本地目标参数为多次训练中的最后一次训练下而得到的候选参数。
18.根据权利要求15或16或17所述的装置,其中,
在多次训练的非首次训练下,所述本地模型的初始模型参数为所述非首次训练的前一次训练下而得到的所述本地模型的候选参数;
在多次训练的首次训练下,所述本地模型的初始模型参数基于所述更新指令而得到。
19.根据权利要求14至18中任一项所述的装置,其中,所述完成参数更新的全局模型用于识别待识别图像中的目标对象。
20.根据权利要求14至18中任一项所述的装置,其中,所述完成参数更新的全局模型通过采用全局动量参数对所述全局模型的全局参考参数进行调整而得到;
其中,所述全局动量参数基于至少一个接收到的本地目标参数而得到的。
21.根据权利要求20所述的装置,其中,所述全局动量参数基于全局参考动量和全局梯度而得到,其中所述全局梯度基于所述至少一个所述第一设备的本地目标参数之间的聚合结果而得到。
22.一种机器学习模型训练装置,包括:
接收模块,用于接收针对全局模型的更新指令而产生的多个本地目标参数,所述本地目标参数为权利要求14-18任一所述装置中的本地目标参数;
第一获得模块,用于基于对所述多个本地目标参数的聚合结果,得到全局模型的全局动量参数;
第二获得模块,用于采用全局动量参数,对所述全局模型进行参数更新,得到目标全局模型,所述目标全局模型用于识别待识别数据中的目标对象。
23.根据权利要求22所述的装置,其中,所述聚合结果通过采用预设聚合算法对所述多个本地目标参数进行聚合而得到。
24.根据权利要求22或23所述的装置,其中,所述第一获得模块,用于:
基于所述多个本地目标参数的聚合结果,得到所述全局模型的全局梯度;
基于所述全局梯度和全局参考动量,得到全局模型的全局动量参数。
25.根据权利要求24所述的装置,其中,所述第二获得模块,用于:
采用全局动量参数,对全局模型的全局参考参数进行调整,得到全局目标参数;
基于所述全局目标参数,得到目标全局模型。
26.一种识别装置,包括:
第一获得单元,用于获得待识别数据;
第二获得单元,用于将所述待识别数据输入至权利要求22至25中任一项所述的目标全局模型,得到所述待识别数据中的目标对象。
27.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-8和/或权利要求9-13中任一项所述的方法。
28.一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使计算机执行权利要求1-8和/或权利要求9-13中任一项所述的方法。
29.一种计算机程序产品,包括计算机程序,该计算机程序在被处理器执行时实现权利要求1-8和/或权利要求9-13中任一项所述的方法。
CN202211282362.5A 2022-10-19 2022-10-19 机器学习模型训练方法、识别方法、相关装置及电子设备 Withdrawn CN115600693A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211282362.5A CN115600693A (zh) 2022-10-19 2022-10-19 机器学习模型训练方法、识别方法、相关装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211282362.5A CN115600693A (zh) 2022-10-19 2022-10-19 机器学习模型训练方法、识别方法、相关装置及电子设备

Publications (1)

Publication Number Publication Date
CN115600693A true CN115600693A (zh) 2023-01-13

Family

ID=84849566

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211282362.5A Withdrawn CN115600693A (zh) 2022-10-19 2022-10-19 机器学习模型训练方法、识别方法、相关装置及电子设备

Country Status (1)

Country Link
CN (1) CN115600693A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116187473A (zh) * 2023-01-19 2023-05-30 北京百度网讯科技有限公司 联邦学习方法、装置、电子设备和计算机可读存储介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116187473A (zh) * 2023-01-19 2023-05-30 北京百度网讯科技有限公司 联邦学习方法、装置、电子设备和计算机可读存储介质
CN116187473B (zh) * 2023-01-19 2024-02-06 北京百度网讯科技有限公司 联邦学习方法、装置、电子设备和计算机可读存储介质

Similar Documents

Publication Publication Date Title
US11941527B2 (en) Population based training of neural networks
TWI767000B (zh) 產生波形之方法及電腦儲存媒體
JP2019071080A (ja) バッチ正規化レイヤ
JP6896176B2 (ja) システム強化学習方法及び装置、電子機器、コンピュータ記憶媒体並びにコンピュータプログラム
KR20200031163A (ko) 신경 네트워크 구조의 생성 방법 및 장치, 전자 기기, 저장 매체
WO2017091629A1 (en) Reinforcement learning using confidence scores
CN114065863B (zh) 联邦学习的方法、装置、系统、电子设备及存储介质
CN113657289B (zh) 阈值估计模型的训练方法、装置和电子设备
US20220148239A1 (en) Model training method and apparatus, font library establishment method and apparatus, device and storage medium
US11875584B2 (en) Method for training a font generation model, method for establishing a font library, and device
CN113610989B (zh) 风格迁移模型训练方法和装置、风格迁移方法和装置
US20220237935A1 (en) Method for training a font generation model, method for establishing a font library, and device
CN113657483A (zh) 模型训练方法、目标检测方法、装置、设备以及存储介质
CN115147680B (zh) 目标检测模型的预训练方法、装置以及设备
CN114020950A (zh) 图像检索模型的训练方法、装置、设备以及存储介质
US20240070454A1 (en) Lightweight model training method, image processing method, electronic device, and storage medium
US20220398834A1 (en) Method and apparatus for transfer learning
CN115600693A (zh) 机器学习模型训练方法、识别方法、相关装置及电子设备
CN113344213A (zh) 知识蒸馏方法、装置、电子设备及计算机可读存储介质
CN114926322B (zh) 图像生成方法、装置、电子设备和存储介质
CN114078184B (zh) 数据处理方法、装置、电子设备和介质
CN115880506A (zh) 图像生成方法、模型的训练方法、装置及电子设备
CN114067415A (zh) 回归模型的训练方法、对象评估方法、装置、设备和介质
CN116797829B (zh) 一种模型生成方法、图像分类方法、装置、设备及介质
CN114239608B (zh) 翻译方法、模型训练方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WW01 Invention patent application withdrawn after publication

Application publication date: 20230113

WW01 Invention patent application withdrawn after publication