CN117057442A - 一种基于联邦多任务学习的模型训练方法、装置及设备 - Google Patents

一种基于联邦多任务学习的模型训练方法、装置及设备 Download PDF

Info

Publication number
CN117057442A
CN117057442A CN202311298511.1A CN202311298511A CN117057442A CN 117057442 A CN117057442 A CN 117057442A CN 202311298511 A CN202311298511 A CN 202311298511A CN 117057442 A CN117057442 A CN 117057442A
Authority
CN
China
Prior art keywords
client
clients
model parameters
parameters corresponding
optimization model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311298511.1A
Other languages
English (en)
Inventor
黄章敏
孙辰晨
戴雨洋
李勇
张莹
程稳
陈�光
曾令仿
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202311298511.1A priority Critical patent/CN117057442A/zh
Publication of CN117057442A publication Critical patent/CN117057442A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Medical Informatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Molecular Biology (AREA)
  • Databases & Information Systems (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本说明书公开了一种基于联邦多任务学习的模型训练方法、装置及设备,中心服务器将各客户端对应的初始模型参数发送给各客户端,以使各客户端对基于各自的初始模型参数得到的模型进行训练,并将训练后的模型的优化模型参数返回给中心服务器,中心服务器根据各客户端对应的优化模型参数,确定各客户端对应的对优化模型参数进行加权的权重,并根据各客户端对应的对各优化模型参数进行加权的权重,确定适用于各客户端的模型参数,得到适用于各客户端的模型。由于各客户端的数据分布存在差异,因此本方法在模型的每次迭代训练过程中,根据权重确定各客户端的模型参数,使得各客户端得到更加泛化的模型的同时,可得到适用于各自数据分布的个性化模型。

Description

一种基于联邦多任务学习的模型训练方法、装置及设备
技术领域
本申请涉及计算机技术领域,尤其涉及一种基于联邦多任务学习的模型训练方法、装置及设备。
背景技术
随着科技的飞速发展,机器学习模型得到广泛应用。其中,联邦学习是模型的分布式训练中的一种,联邦学习框架中包括一个中心服务器和多个参与训练的客户端,其主要目的是保护各客户端中的隐私数据不泄露。
一般的,在实际应用场景中,各客户端的数据分布不同,也即每个客户端的数据都具一定特点。例如:在对一本小说进行打分时,由于每个人的喜好、教育背景的差异等,对该本小说的打分不同,并且不同客户端适应的人群不同,那么每个客户端中的数据分布是不同的,具有其自身的特性,因此,如何基于联邦学习进行模型的训练,以使得各参与训练的客户端可以得到适用于其本身数据分布的更加精准的模型是一个重点问题。
基于此,本申请说明书提供了一种基于联邦多任务学习的模型训练方法。
发明内容
本说明书提供一种基于联邦多任务学习的模型训练方法、装置、存储介质及电子设备,以至少部分的解决现有技术存在的上述问题。
本说明书采用下述技术方案:
本说明书提供了一种基于联邦多任务学习的模型训练方法,所述方法应用于分布式系统中的中心服务器,所述方法包括:
向各客户端发送所述各客户端各自对应的初始模型参数,以使所述各客户端分别对根据各自对应的初始模型参数得到的待训练模型进行训练,得到所述各客户端对应的待训练模型的优化模型参数;
接收所述各客户端发送的所述各客户端对应的优化模型参数;
针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重;
根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及所述各客户端对应的优化模型参数,确定该客户端对应的加权优化参数;
将确定出的各客户端对应的加权优化参数重新作为所述各客户端各自对应的初始模型参数,并分别发送至所述各客户端,以使所述各客户端继续对所述各客户端各自对应的待训练模型进行训练。
可选地,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重,具体包括:
确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度;
根据确定出的该客户端对应的各相似度,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
可选地,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度,具体包括:
确定该客户端对应的优化模型参数的参数向量与其他客户端对应的优化模型参数的参数向量之间的欧氏距离矩阵;
根据确定出的该客户端对应的欧氏距离矩阵,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度。
可选地,确定该客户端对其他客户端对应的优化模型参数进行加权的权重,具体包括:
对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。
可选地,对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重,具体包括:
使用softmax函数对该客户端对应的各相似度进行归一化,得到所述softmax函数输出各相似度的概率;
将所述softmax函数输出各相似度的概率作为该客户端对其他客户端对应的优化模型参数进行加权的权重。
本说明书提供了一种基于联邦多任务学习的模型训练装置,所述装置应用于分布式系统中的中心服务器,包括:
发送模块,用于向各客户端发送所述各客户端各自对应的初始模型参数,以使所述各客户端分别对根据各自对应的初始模型参数得到的待训练模型进行训练,得到所述各客户端对应的待训练模型的优化模型参数;
接收模块,用于接收所述各客户端发送的所述各客户端对应的优化模型参数;
确定模块,用于针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重;
加权模块,用于根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及所述各客户端对应的优化模型参数,确定该客户端对应的加权优化参数;
训练模块,用于将确定出的各客户端对应的加权优化参数重新作为所述各客户端各自对应的初始模型参数,并分别发送至所述各客户端,以使所述各客户端继续对所述各客户端各自对应的待训练模型进行训练。
可选地,所述确定模块具体用于,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度;根据确定出的该客户端对应的各相似度,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
可选地,所述确定模块具体用于,确定该客户端对应的优化模型参数的参数向量与其他客户端对应的优化模型参数的参数向量之间的欧氏距离矩阵;根据确定出的该客户端对应的欧氏距离矩阵,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度。
可选地,所述确定模块具体用于,对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。
可选地,所述确定模块具体用于,使用softmax函数对该客户端对应的各相似度进行归一化,得到所述softmax函数输出各相似度的概率;将所述softmax函数输出各相似度的概率作为该客户端对其他客户端对应的优化模型参数进行加权的权重。
本说明书提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述基于联邦多任务学习的模型训练方法。
本说明书提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述基于联邦多任务学习的模型训练方法。
本说明书采用的上述至少一个技术方案能够达到以下有益效果:
在本说明书提供的基于联邦多任务学习的模型训练方法中可以看出,在进行基于联邦多任务学习的模型训练时,中心服务器将各客户端对应的初始模型参数发送给各客户端,以使各客户端对基于各自的初始模型参数得到的模型进行训练,并将训练后的模型的优化模型参数返回给中心服务器,中心服务器可根据各客户端对应的优化模型参数,确定各客户端对应的对优化模型参数进行加权的权重,进而可根据各客户端对应的对各优化模型参数进行加权的权重,确定适用于各客户端的模型参数,从而得到适用于各客户端的模型。
区别于目前的直接对各优化模型参数取平均值以发送给各客户端的方法,本方法可以使得各客户端得到适用于各自数据分布的个性化模型,因为不同的客户端中的用户不同,用户的喜好、教育背景的差异等会使得各客户端的数据分布不同,那么各客户端的模型学习数据特征时的侧重不同,因此本方法在每次迭代训练过程中,确定对各客户端对应的对优化模型参数进行加权的权重,从而确定各客户端的模型参数,使得各客户端通过联邦学习实现间接的样本数据分享,得到更加泛化的模型的同时,可以得到适用于各自数据分布的更加精确的模型。
附图说明
此处所说明的附图用来提供对本说明书的进一步理解,构成本说明书的一部分,本说明书的示意性实施例及其说明用于解释本说明书,并不构成对本说明书的不当限定。在附图中:
图1为本说明书中一种基于联邦多任务学习的模型训练方法的流程示意图;
图2为本说明书提供的基于各客户端对应的模型参数确定出的欧氏距离矩阵;
图3为本说明书提供的基于联邦多任务学习的模型训练方法的框架示意图;
图4为本说明书提供的一种基于联邦多任务学习的模型训练装置示意图;
图5为本说明书提供的对应于图1的电子设备示意图。
具体实施方式
联邦学习广泛应用于各种各样的场景中,例如:在大数据医疗领域,在训练机器学习模型时,为了获取足够多的样本数据,需要各医院共享其本地数据以构建数据集,进行模型的训练,然而由于各家医疗机构保存的电子病历、医学图像(如:CT图像、MRI图像)等数据,涉及到就诊人员的隐私,因此无法实现数据共享。而联邦学习算法可以解决这一问题,联邦学习是模型的分布式训练中的一种,联邦学习框架中包括一个中心服务器和多个参与训练的客户端,其主要目的是保护各客户端中的隐私数据不泄露。具体的,联邦学习框架中的各个客户端可接收中心服务器发送的模型参数,针对每个客户端,该客户端可以以该模型参数生成机器学习模型,并将该客户端本地保存的隐私数据作为训练样本输入该机器学习模型,根据机器学习模型输出的结果和训练样本对应的标注得到梯度(或者新的模型参数),再将得到的梯度上传给中心服务器,以使中心服务器对各客户端发送的梯度取平均并进行模型参数的更新,以此迭代。因此使用联邦学习算法训练模型,在各医院无需共享本地的隐私数据的情况下,可实现基于各医院的隐私数据对机器学习模型进行训练的目标,且可以有效地保护各客户端本地的就诊人员的隐私数据。
此外,由于各医院的就诊人员的年龄、运动频率、生活习惯等存在差异,往往各医院的数据分布不同,都具有其自身特点,因此在基于联邦学习训练模型时,还要考虑如何保留各医院的数据的特点,使得训练出的模型可以更好的拟合每个医院本地的隐私数据,以提高模型的精确度。基于此,本申请说明书提供了一种基于联邦多任务学习的模型训练方法,可以使得各客户端通过联邦多任务学习实现间接的样本数据分享,得到更加泛化的模型的同时,可以得到适用于各客户端自身数据分布的更加精确的模型。
为使本说明书的目的、技术方案和优点更加清楚,下面将结合本说明书具体实施例及相应的附图对本说明书技术方案进行清楚、完整地描述。显然,所描述的实施例仅是本说明书一部分实施例,而不是全部的实施例。基于本说明书中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于说明书保护的范围。
以下结合附图,详细说明本说明书各实施例提供的技术方案。
图1为本说明书提供的一种基于联邦多任务学习的模型训练方法的流程示意图,具体可包括以下步骤:
S100:向各客户端发送所述各客户端各自对应的初始模型参数,以使所述各客户端分别对根据各自对应的初始模型参数得到的待训练模型进行训练,得到所述各客户端对应的待训练模型的优化模型参数。
一般的,在医疗领域中在进行医学图像分类时,各医院的医疗图像数据是保密的,且由于各医院地理位置、环境等因素以及各医院的就诊人员的年龄、性别等差异,各家医院都具有适用于其自身数据特点的分类标准,因此各医院的医疗图像的分类类别是不平衡的,也即各医院确定各医学图像的分类标签是不同的。
因此在对模型进行训练时,为了使得各客户端的模型在学习数据特征时更侧重于学习各客户端自身的数据特征,也即为了训练出适用于各医院自身数据分布的图像分类模型,在本说明书中,各医院的本地客户端可接收中心服务器发送的初始模型参数,并且每个客户端可分别对各自对应的初始模型参数得到的待训练模型进行训练,得到各客户端对应的待训练模型的优化模型参数。
具体的,中心服务器可先确定待训练的图像分类模型的初始模型参数,并向客户端发送各客户端对应的初始模型参数。在本说明书中,在首轮迭代时,各客户端的初始模型参数可相同。进而,各客户端可根据各自接收到的初始模型参数,确定出待训练的图像分类模型,并使用各客户端本地数据对各客户端对应的待训练的图像分类模型进行训练,得到各到训练后的图像分类模型,并确定该训练后的图像分类模型中的模型参数,作为优化模型参数。
在本说明书的一个或多个实施例中,该图像分类模型可为卷积神经网络模型。
S102:接收所述各客户端发送的所述各客户端对应的优化模型参数。
S104:针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
则在各客户端对待训练的模型进行训练得到优化模型参数之后,各客户端可将各自对应的优化模型参数返回给中心服务器,如图2所示,为本说明书提供的基于联邦多任务学习的模型训练方法的框架示意图。
在本说明书中,为了使得待训练的模型可以更好的拟合各客户端本地的数据,中心服务器可接收各客户端发送的各客户端对应的优化模型参数,并且针对每个客户端,中心服务器可根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
具体的,针对每个客户端,在确定该客户端对其他客户端对应的优化模型参数进行加权的权重时,中心服务器可通过确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度,以根据确定出的该客户端对应的各相似度,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
在本说明书的一个或多个实施例中,在确定客户端之间的优化模型参数之间的相似度时,可针对每个客户端,确定该客户端对应的优化模型参数的参数向量与其他客户端对应的优化模型参数的参数向量之间的欧氏距离,得到欧氏距离矩阵,进而根据确定出的该客户端对应的各欧氏距离矩阵,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度。
此外,在本说明书的一个或多个实施例中,由于两个客户端对应的参数向量之间的距离越大,表征两个客户端中的数据特征差异越大,也即两个客户端的数据特征之间的帮助程度越小,因此中心服务器可针对每个客户端,对该客户端对应的欧氏距离矩阵中的每一个值添加一个负号,表征距离与相似度之间的负相关趋势。当然,也可使用其他方法,具体本说明书不做限制。
S106:根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及所述各客户端对应的优化模型参数,确定该客户端对应的加权优化参数。
进而中心服务器可根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及各客户端对应的优化模型参数,确定该客户端对应的加权优化参数。
具体的,在本说明书的一个或多个实施例中,中心服务器可针对每个客户端,对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。并且,中心服务器可使用softmax函数对该客户端对应的各相似度进行归一化,得到softmax函数输出各相似度的概率,将softmax函数输出各相似度的概率作为该客户端对其他客户端对应的优化模型参数进行加权的权重。
也就是说,在上述步骤S104中,针对每个客户端,中心服务器在对该客户端对应的欧氏距离矩阵中每一项添加负号之后,可将添加负号之后的每一项输入softmax函数,得到softmax函数输出各添加负号之后的每一项的概率,并将该概率作为权重。对欧氏距离矩阵中每一项添加负号之后,softmax函数输出的每一项对应的概率值就小,与距离与相似度之间的负相关相符。如图3所示,为本说明书提供的基于各客户端对应的模型参数确定出的欧氏距离矩阵,每一行代表一个客户端的欧氏距离矩阵,以客户端1为例,则第一行为客户端1与其他客户端之间的欧氏距离矩阵,即[2 7](除去客户端1本身),该矩阵中的每一项代表该客户端1与其他客户端(即客户端2与客户端3)之间的欧氏距离。
则在本说明书中,针对每个客户端,中心服务器可使用该客户端对其他客户端对应的优化模型参数进行加权的权重,对其他客户端对应的优化模型参数进行加权求和,得到该客户端对应的加权优化参数。
需要说明的是,在本说明书的一个或多个实施例中,softmax函数是随着待训练模型的每次迭代进行更新的,也即在各客户端将优化模型参数发送给中心服务器后,该中心服务器可根据接收到的各客户端对应的优化模型参数,更新softmax函数,得到各客户端对应的softmax函数。因此,在确定该客户端对应的加权优化参数时,中心服务器可针对每个客户端,根据该客户端对应的优化模型参数对softmax函数进行更新,并使用更新后的softmax函数对该客户端对应的各相似度进行归一化,得到softmax函数输出各相似度的概率,并将softmax函数输出各相似度的概率作为该客户端对其他客户端对应的优化模型参数进行加权的权重。
S108:将确定出的各客户端对应的加权优化参数重新作为所述各客户端各自对应的初始模型参数,并分别发送至所述各客户端,以使所述各客户端继续对所述各客户端各自对应的待训练模型进行训练。
中心服务器确定出各客户端对其他客户端对应的优化模型参数进行加权的权重之后,可将确定出的各客户端对应的加权优化参数重新作为各客户端各自对应的初始模型参数,并分别发送至各客户端,以使各客户端继续对各客户端各自对应的待训练模型进行训练。
此外,在本说明书一个或多个实施例中,待训练的模型可为图像分类模型,在客户端对待训练的图像分类模型进行训练时,可先基于本地数据确定各样本图像以及各样本图像对应的分类标签,便可根据各样本图像及其对应的分类标签,对待训练的图像分类检测模型进行训练。也即客户端可将样本图像输入待训练的图像分类模型,得到待训练的图像分类模型输出的预测结果,然后,将样本图像对应的预测结果以及样本图像对应的分类标签,输入到损失函数中,根据损失函数计算损失,并确定使损失最小的梯度,根据梯度下降的方向来调整图像分类模型的模型参数,使样本图像对应的预测结果与样本图像对应的分类标签之间的差异最小。按照上述方法,进行待训练的图像分类模型的迭代训练。当然,具体何时确定该图像分类模型的训练结束,本说明书不做具体限制,例如当训练迭代次数达到预设阈值时,确定该图像分类模型的训练结束,或者当确定出的损失小于预设数值时,确定该图像分类模型的训练结束。
基于图1所示本说明书提供的上述基于联邦多任务学习的模型训练方法中,在进行基于联邦多任务学习的模型训练时,中心服务器将各客户端对应的初始模型参数发送给各客户端,以使各客户端对基于各自的初始模型参数得到的模型进行训练,并将训练后的模型的优化模型参数返回给中心服务器,中心服务器可根据各客户端对应的优化模型参数,确定各客户端对应的对优化模型参数进行加权的权重,进而可根据各客户端对应的对各优化模型参数进行加权的权重,确定适用于各客户端的模型参数,从而得到适用于各客户端的模型。区别于目前的直接对各优化模型参数取平均值以发送给各客户端的方法,本方法可以使得各客户端得到适用于各自数据分布的个性化模型,因为不同的客户端中的用户不同,用户的喜好、教育背景的差异等会使得各客户端的数据分布不同,那么各客户端中的模型学习数据特征时的侧重也不同,因此本方法在每次迭代训练过程中,确定对各客户端对应的对优化模型参数进行加权的权重,从而确定各客户端的模型参数,使得在各客户端可通过数据分享得到更加泛化的模型的同时,可以得到适用于各自数据分布的更加精确的个性化模型。
另外,在进行本说明书提供的上述模型的训练方法时,可先通过实验来验证该基于联邦多任务学习的模型训练方法,以及验证本申请说明书提供的上述方法的优点。其中,实验时的试验数据集可采用机器学习中常用的经典数据集,如:CIFAR-10、CIFAR-100、ImageNet等,并且可利用深度学习开源框架PyTorch,进行实验的计算设备的配置可包括:CPU-i7、内存64G、NVIDIA V100 GPU等。
进一步的,在上述步骤S104中,针对每个客户端,在中心服务器确定该客户端对其他客户端对应的优化模型参数进行加权的权重时,可使用其他方法,只要能够使得输出的权重与相似度呈正相关、与优化模型参数之间的距离呈负相关即可。例如:可设定最大距离阈值,使用该最大距离阈值减去该客户端对应的欧氏距离矩阵中的每一项,得到的差值为该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的相似度矩阵,进而可对该客户端对应的相似度矩阵中的每一项进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。
此外,在本说明书的一个或多个实施例中,无论是上述步骤S104中还是上述步骤S106中,针对每个客户端,在确定该客户端对其他客户端对应的优化模型参数进行加权时,可将该客户端对该客户端对应的优化模型参数进行加权的权重设为1。沿用上述步骤S106中的图3,对于客户端1而言,其与其他客户端之间的欧氏距离矩阵为[27],进一步假设,对欧氏距离矩阵中的每一项添加负号之后,将[-2-7]输入softmax函数得到的概率为:0.6与0.4,那么客户端1对应的加权优化参数即为:(1×客户端1的优化模型参数+0.6×客户端2的优化模型参数+0.4×客户端3的优化模型参数)/2。
相应的,客户端1的对自身的优化模型参数加权的权重也可经由softmax函数得到,即不设为1。则在本说明书的一个或多个实施例中,在上述步骤S104中,在针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重时,可根据该客户端对应的优化模型参数,确定该客户端对各客户端对应的优化模型参数进行加权的权重,也即也要确定该客户端对该客户端对应的优化模型参数进行加权的权重。
在确定该客户端对各客户端对应的优化模型参数进行加权的权重时具体可为:中心服务器可通过确定该客户端对应的优化模型参数与各客户端对应的优化模型参数之间的各相似度,以根据确定出的该客户端对应的各相似度,确定该客户端对各客户端对应的优化模型参数进行加权的权重。在确定客户端之间的优化模型参数之间的相似度时,可针对每个客户端,确定该客户端对应的优化模型参数的参数向量与各客户端对应的优化模型参数的参数向量之间的欧氏距离,得到欧氏距离矩阵,进而根据确定出的该客户端对应的各欧氏距离矩阵,确定该客户端对应的优化模型参数与各客户端对应的优化模型参数之间的各相似度。
沿用上述步骤S106中的图3,对于客户端1而言,客户端1与各客户端之间的欧氏距离矩阵为[0 27]。
在上述步骤S106中,在针对每个客户端,在根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及各客户端对应的优化模型参数,确定该客户端对应的加权优化参数时,可为针对每个客户端,根据确定出的该客户端对各客户端对应的优化模型参数进行加权的权重以及各客户端对应的优化模型参数,确定该客户端对应的加权优化参数。
沿用上述步骤S106中的图3,进一步假设,对欧氏距离矩阵中的每一项添加负号之后,将[0 -2-7]输入softmax函数得到[0.7 0.20.1],即得到的各权重分别为:0.7、0.2以及0.1,那么客户端1对应的加权优化参数即为:0.7×客户端1的优化模型参数+0.2×客户端2的优化模型参数+0.1×客户端3的优化模型参数。也即在本说明书中,针对每个客户端,中心服务器可使用softmax函数对该客户端对应的各相似度进行归一化,得到softmax函数输出各相似度的概率,将softmax函数输出各相似度的概率作为该客户端对各客户端对应的优化模型参数进行加权的权重,并使用该客户端对各客户端对应的优化模型参数进行加权的权重,对各客户端对应的优化模型参数进行加权求和,得到该客户端对应的加权优化参数。
基于上述内容所述的基于联邦多任务学习的模型训练方法,本说明书实施例还对应的提供一种用于基于联邦多任务学习的模型训练装置示意图,如图4所示。
图4为本说明书实施例提供的一种用于基于联邦多任务学习的模型训练装置的示意图,所述装置包括:
发送模块400,用于向各客户端发送所述各客户端各自对应的初始模型参数,以使所述各客户端分别对根据各自对应的初始模型参数得到的待训练模型进行训练,得到所述各客户端对应的待训练模型的优化模型参数;
接收模块402,用于接收所述各客户端发送的所述各客户端对应的优化模型参数;
确定模块404,用于针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重;
加权模块406,用于根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及所述各客户端对应的优化模型参数,确定该客户端对应的加权优化参数;
训练模块408,用于将确定出的各客户端对应的加权优化参数重新作为所述各客户端各自对应的初始模型参数,并分别发送至所述各客户端,以使所述各客户端继续对所述各客户端各自对应的待训练模型进行训练。
可选地,所述确定模块404具体用于,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度;根据确定出的该客户端对应的各相似度,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
可选地,所述确定模块404具体用于,确定该客户端对应的优化模型参数的参数向量与其他客户端对应的优化模型参数的参数向量之间的欧氏距离矩阵;根据确定出的该客户端对应的欧氏距离矩阵,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度。
可选地,所述确定模块404具体用于,对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。
可选地,所述确定模块404具体用于,使用softmax函数对该客户端对应的各相似度进行归一化,得到所述softmax函数输出各相似度的概率;将所述softmax函数输出各相似度的概率作为该客户端对其他客户端对应的优化模型参数进行加权的权重。
本说明书实施例还提供了一种计算机可读存储介质,该存储介质存储有计算机程序,计算机程序可用于执行上述内容所述的基于联邦多任务学习的模型训练方法。
基于上述内容所述的基于联邦多任务学习的模型训练方法,本说明书实施例还提出了图5所示的电子设备的示意结构图。如图5,在硬件层面,该电子设备包括处理器、内部总线、网络接口、内存以及非易失性存储器,当然还可能包括其他业务所需要的硬件。处理器从非易失性存储器中读取对应的计算机程序到内存中然后运行,以实现上述内容所述的基于联邦多任务学习的模型训练方法。
当然,除了软件实现方式之外,本说明书并不排除其他实现方式,比如逻辑器件抑或软硬件结合的方式等等,也就是说以下处理流程的执行主体并不限定于各个逻辑单元,也可以是硬件或逻辑器件。
在20世纪90年代,对于一个技术的改进可以很明显地区分是硬件上的改进(例如,对二极管、晶体管、开关等电路结构的改进)还是软件上的改进(对于方法流程的改进)。然而,随着技术的发展,当今的很多方法流程的改进已经可以视为硬件电路结构的直接改进。设计人员几乎都通过将改进的方法流程编程到硬件电路中来得到相应的硬件电路结构。因此,不能说一个方法流程的改进就不能用硬件实体模块来实现。例如,可编程逻辑器件(Programmable Logic Device, PLD)(例如现场可编程门阵列(Field Programmable GateArray,FPGA))就是这样一种集成电路,其逻辑功能由用户对器件编程来确定。由设计人员自行编程来把一个数字系统“集成”在一片PLD上,而不需要请芯片制造厂商来设计和制作专用的集成电路芯片。而且,如今,取代手工地制作集成电路芯片,这种编程也多半改用“逻辑编译器(logic compiler)”软件来实现,它与程序开发撰写时所用的软件编译器相类似,而要编译之前的原始代码也得用特定的编程语言来撰写,此称之为硬件描述语言(Hardware Description Language,HDL),而HDL也并非仅有一种,而是有许多种,如ABEL(Advanced Boolean Expression Language)、AHDL(Altera Hardware DescriptionLanguage)、Confluence、CUPL(Cornell University Programming Language)、HDCal、JHDL(Java Hardware Description Language)、Lava、Lola、MyHDL、PALASM、RHDL(RubyHardware Description Language)等,目前最普遍使用的是VHDL(Very-High-SpeedIntegrated Circuit Hardware Description Language)与Verilog。本领域技术人员也应该清楚,只需要将方法流程用上述几种硬件描述语言稍作逻辑编程并编程到集成电路中,就可以很容易得到实现该逻辑方法流程的硬件电路。
控制器可以按任何适当的方式实现,例如,控制器可以采取例如微处理器或处理器以及存储可由该(微)处理器执行的计算机可读程序代码(例如软件或固件)的计算机可读介质、逻辑门、开关、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程逻辑控制器和嵌入微控制器的形式,控制器的例子包括但不限于以下微控制器:ARC 625D、Atmel AT91SAM、Microchip PIC18F26K20 以及Silicone Labs C8051F320,存储器控制器还可以被实现为存储器的控制逻辑的一部分。本领域技术人员也知道,除了以纯计算机可读程序代码方式实现控制器以外,完全可以通过将方法步骤进行逻辑编程来使得控制器以逻辑门、开关、专用集成电路、可编程逻辑控制器和嵌入微控制器等的形式来实现相同功能。因此这种控制器可以被认为是一种硬件部件,而对其内包括的用于实现各种功能的装置也可以视为硬件部件内的结构。或者甚至,可以将用于实现各种功能的装置视为既可以是实现方法的软件模块又可以是硬件部件内的结构。
上述实施例阐明的系统、装置、模块或单元,具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现。一种典型的实现设备为计算机。具体的,计算机例如可以为个人计算机、膝上型计算机、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任何设备的组合。
为了描述的方便,描述以上装置时以功能分为各种单元分别描述。当然,在实施本说明书时可以把各单元的功能在同一个或多个软件和/或硬件中实现。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
在一个典型的配置中,计算设备包括一个或多个处理器(CPU)、输入/输出接口、网络接口和内存。
内存可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM)。内存是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitory media),如调制的数据信号和载波。
还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、商品或者设备中还存在另外的相同要素。
本领域技术人员应明白,本说明书的实施例可提供为方法、系统或计算机程序产品。因此,本说明书可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本说明书可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本说明书可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本说明书,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于系统实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本说明书的实施例而已,并不用于限制本说明书。对于本领域技术人员来说,本说明书可以有各种更改和变化。凡在本说明书的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本申请的权利要求范围之内。

Claims (10)

1.一种基于联邦多任务学习的模型训练方法,其特征在于,所述方法应用于分布式系统中的中心服务器,所述方法包括:
向各客户端发送所述各客户端各自对应的初始模型参数,以使所述各客户端分别对根据各自对应的初始模型参数得到的待训练模型进行训练,得到所述各客户端对应的待训练模型的优化模型参数;
接收所述各客户端发送的所述各客户端对应的优化模型参数;
针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重;
根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及所述各客户端对应的优化模型参数,确定该客户端对应的加权优化参数;
将确定出的各客户端对应的加权优化参数重新作为所述各客户端各自对应的初始模型参数,并分别发送至所述各客户端,以使所述各客户端继续对所述各客户端各自对应的待训练模型进行训练。
2.如权利要求1所述的方法,其特征在于,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重,具体包括:
确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度;
根据确定出的该客户端对应的各相似度,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
3.如权利要求2所述的方法,其特征在于,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度,具体包括:
确定该客户端对应的优化模型参数的参数向量与其他客户端对应的优化模型参数的参数向量之间的欧氏距离矩阵;
根据确定出的该客户端对应的欧氏距离矩阵,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度。
4.如权利要求2所述的方法,其特征在于,确定该客户端对其他客户端对应的优化模型参数进行加权的权重,具体包括:
对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。
5.如权利要求4所述的方法,其特征在于,对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重,具体包括:
使用softmax函数对该客户端对应的各相似度进行归一化,得到所述softmax函数输出各相似度的概率;
将所述softmax函数输出各相似度的概率作为该客户端对其他客户端对应的优化模型参数进行加权的权重。
6.一种基于联邦多任务学习的模型训练装置,其特征在于,所述装置应用于分布式系统中的中心服务器,所述装置具体包括:
发送模块,用于向各客户端发送所述各客户端各自对应的初始模型参数,以使所述各客户端分别对根据各自对应的初始模型参数得到的待训练模型进行训练,得到所述各客户端对应的待训练模型的优化模型参数;
接收模块,用于接收所述各客户端发送的所述各客户端对应的优化模型参数;
确定模块,用于针对每个客户端,根据该客户端对应的优化模型参数,确定该客户端对其他客户端对应的优化模型参数进行加权的权重;
加权模块,用于根据确定出的该客户端对其他客户端对应的优化模型参数进行加权的权重以及所述各客户端对应的优化模型参数,确定该客户端对应的加权优化参数;
训练模块,用于将确定出的各客户端对应的加权优化参数重新作为所述各客户端各自对应的初始模型参数,并分别发送至所述各客户端,以使所述各客户端继续对所述各客户端各自对应的待训练模型进行训练。
7.如权利要求6所述的装置,其特征在于,所述确定模块具体用于,确定该客户端对应的优化模型参数与其他客户端对应的优化模型参数之间的各相似度;根据确定出的该客户端对应的各相似度,确定该客户端对其他客户端对应的优化模型参数进行加权的权重。
8.如权利要求6所述的装置,其特征在于,所述确定模块具体用于,对该客户端对应的各相似度进行归一化,得到该客户端对其他客户端对应的优化模型参数进行加权的权重。
9.一种计算机可读存储介质,其特征在于,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述权利要求1-5任一所述的方法。
10.一种电子设备,其特征在于,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述权利要求1-5任一所述的方法。
CN202311298511.1A 2023-10-09 2023-10-09 一种基于联邦多任务学习的模型训练方法、装置及设备 Pending CN117057442A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311298511.1A CN117057442A (zh) 2023-10-09 2023-10-09 一种基于联邦多任务学习的模型训练方法、装置及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311298511.1A CN117057442A (zh) 2023-10-09 2023-10-09 一种基于联邦多任务学习的模型训练方法、装置及设备

Publications (1)

Publication Number Publication Date
CN117057442A true CN117057442A (zh) 2023-11-14

Family

ID=88661163

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311298511.1A Pending CN117057442A (zh) 2023-10-09 2023-10-09 一种基于联邦多任务学习的模型训练方法、装置及设备

Country Status (1)

Country Link
CN (1) CN117057442A (zh)

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115344883A (zh) * 2022-06-29 2022-11-15 上海工程技术大学 一种用于处理不平衡数据的个性化联邦学习方法和装置
CN115840900A (zh) * 2022-09-16 2023-03-24 河海大学 一种基于自适应聚类分层的个性化联邦学习方法及系统
CN116205311A (zh) * 2023-02-16 2023-06-02 同济大学 一种基于Shapley值的联邦学习方法
CN116542296A (zh) * 2023-05-04 2023-08-04 北京芯联心科技发展有限公司 基于联邦学习的模型训练方法、装置及电子设备
CN116680565A (zh) * 2023-05-29 2023-09-01 新奥新智科技有限公司 一种联合学习模型训练方法、装置、设备及存储介质

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115344883A (zh) * 2022-06-29 2022-11-15 上海工程技术大学 一种用于处理不平衡数据的个性化联邦学习方法和装置
CN115840900A (zh) * 2022-09-16 2023-03-24 河海大学 一种基于自适应聚类分层的个性化联邦学习方法及系统
CN116205311A (zh) * 2023-02-16 2023-06-02 同济大学 一种基于Shapley值的联邦学习方法
CN116542296A (zh) * 2023-05-04 2023-08-04 北京芯联心科技发展有限公司 基于联邦学习的模型训练方法、装置及电子设备
CN116680565A (zh) * 2023-05-29 2023-09-01 新奥新智科技有限公司 一种联合学习模型训练方法、装置、设备及存储介质

Similar Documents

Publication Publication Date Title
US11880754B2 (en) Electronic apparatus and control method thereof
CN109214193B (zh) 数据加密、机器学习模型训练方法、装置以及电子设备
CN115981870B (zh) 一种数据处理的方法、装置、存储介质及电子设备
US20200167527A1 (en) Method, device, and apparatus for word vector processing based on clusters
CN116049761A (zh) 数据处理方法、装置及设备
CN116912923B (zh) 一种图像识别模型训练方法和装置
CN116342888B (zh) 一种基于稀疏标注训练分割模型的方法及装置
CN116091895A (zh) 一种面向多任务知识融合的模型训练方法及装置
CN116630480A (zh) 一种交互式文本驱动图像编辑的方法、装置和电子设备
CN117057442A (zh) 一种基于联邦多任务学习的模型训练方法、装置及设备
CN116824331A (zh) 一种模型训练、图像识别方法、装置、设备及存储介质
CN116501852B (zh) 一种可控对话模型训练方法、装置、存储介质及电子设备
CN117911630B (zh) 一种三维人体建模的方法、装置、存储介质及电子设备
CN115841335B (zh) 数据处理方法、装置及设备
CN113887326B (zh) 一种人脸图像处理方法及装置
CN117332282B (zh) 一种基于知识图谱的事件匹配的方法及装置
CN116109008B (zh) 一种业务执行的方法、装置、存储介质及电子设备
CN117077817B (zh) 一种基于标签分布的个性化联邦学习模型训练方法及装置
CN117132806A (zh) 一种模型训练的方法、装置、存储介质及电子设备
CN117036830B (zh) 一种肿瘤分类模型训练方法、装置、存储介质及电子设备
CN116186540A (zh) 数据处理方法、装置及设备
CN116386894A (zh) 一种信息溯源方法、装置、存储介质及电子设备
CN116246774A (zh) 一种基于信息融合的分类方法、装置及设备
CN116702131A (zh) 一种数据处理方法、装置及设备
CN117593004A (zh) 数据处理方法、装置及设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination