CN114627202A - 一种基于特异性联邦学习的模型训练方法及装置 - Google Patents
一种基于特异性联邦学习的模型训练方法及装置 Download PDFInfo
- Publication number
- CN114627202A CN114627202A CN202210212867.8A CN202210212867A CN114627202A CN 114627202 A CN114627202 A CN 114627202A CN 202210212867 A CN202210212867 A CN 202210212867A CN 114627202 A CN114627202 A CN 114627202A
- Authority
- CN
- China
- Prior art keywords
- model
- global
- training
- trained
- global sharing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T11/00—2D [Two Dimensional] image generation
- G06T11/003—Reconstruction from projections, e.g. tomography
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- H—ELECTRICITY
- H04—ELECTRIC COMMUNICATION TECHNIQUE
- H04L—TRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
- H04L67/00—Network arrangements or protocols for supporting network services or applications
- H04L67/01—Protocols
- H04L67/10—Protocols in which an application is distributed across nodes in the network
- H04L67/1097—Protocols in which an application is distributed across nodes in the network for distributed storage of data in networks, e.g. transport arrangements for network file system [NFS], storage area networks [SAN] or network attached storage [NAS]
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Data Mining & Analysis (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Computer Networks & Wireless Communication (AREA)
- Signal Processing (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供了一种基于特异性联邦学习的模型训练方法及装置,方法包括:在每一轮通信中,服务器端将全局共享模型发送至每个客户端,每个客户端根据服务器端当前传输的全局共享模型进行局部梯度更新,本地更新完成后,客户端参与服务器端的全局梯度更新,并将更新结果返回至服务器端,服务器端根据客户端返回的更新结果确定下一轮的全局共享模型,并且从第二轮起引入加权对比正则化对客户端的局部梯度更新进行校正;经过多轮通信后,客户端逐渐具有全局共享模型的特征。本申请可以在满足隐私保护机制的同时缓解客户端在训练过程中的域漂移,促进收敛。
Description
技术领域
本申请涉及图像处理技术领域,特别是一种基于特异性联邦学习的模型训练方法及装置。
背景技术
磁共振(Magnetic Resonance,MR)成像已经成为放射学和医学的主流诊断工具。然而,其复杂的成像过程导致比其他方法(如计算机断层扫描(Computed Tomography,CT)、x射线和超声)需要花费更长的采集时间。为了减少扫描时间和改善患者体验,目前已经提出了几种速磁共振成像方法,例如:基于压缩感知的传统方法、字典学习、低秩等。近年来,数据驱动的深度学习方法在MR图像重建方面也有了显著的改进,这主要得益于大量可用的训练数据。然而,基于深度学习的方法所获得的优越结果往往依赖于大量多样的配对数据,而实际上这些数据由于患者隐私问题而难以收集。
最近,有人提出了联邦学习(Federated Learning,FL)算法,它为不同的客户端提供了一个平台,可以使用本地计算力、内存和数据进行协作学习,而不共享任何私有的本地数据。FedAvg是标准且最广泛使用的FL算法之一,它在每一轮通信中收集每个客户端的本地模型,并将它们的平均值分发给每个客户端以备下次更新。由于采用分布式联邦训练,FL在许多领域得到了应用,包括图像分类、目标检测、域泛化、医学图像分割等。然而,在MR图像重建中,不同医院的不同磁共振扫描仪和成像协议存在异质性,导致客户端之间存在域移位。不幸的是,在这些条件下,使用FL训练的模型进行简单的联邦训练仍可能是次优的。技术人员试图通过反复调整和对齐源和目标客户端之间的潜在特征来解决这个问题,这是第一次尝试在MR图像重建中使用FL。
虽然FL已经被应用于MR图像重建,但这种跨站点的方法往往需要牺牲一个客户端作为目标位置,以便在每一轮交流中与其他客户端对齐。很明显,任何被用作目标站点的客户端都将导致隐私泄露的问题,并且跨站点的方法与FL的目的相矛盾,后者是防止客户端通过本地数据相互通信。此外,当客户端数量较大时,由于重复训练和频繁的特征交流,使该过程变得繁琐。更重要的是,这种机制只能学习通用的全局模型,而忽略了各个客户端的特定属性,之前在域自适应方面的研究也表明,编码器通常用来学习共享表示,以确保所有输入都同样适合于任何域变换。因此,虽然FL算法在MR图像重建上做了初步尝试,但其精度仍有待提高。
发明内容
鉴于所述问题,提出了本申请以便提供克服所述问题或者至少部分地解决所述问题的一种基于特异性联邦学习的模型训练方法及装置,包括:
一种基于特异性联邦学习的模型训练方法,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练方法针对所述服务器端;所述模型训练方法包括:
所述服务器端发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;
所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型;
当所述上一轮训练完成的全局共享模型集合非空时,所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
优选的,所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型的步骤之后,还包括:
当所述上一轮训练完成的全局共享模型集合为空时,所述服务器端依据所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
优选的,所述训练完成的全局共享模型集合包含全部训练完成的全局共享模型;所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合的步骤,包括:
对于每个所述当前初步训练完成的全局共享模型,所述服务器端执行如下步骤:
所述服务器端依据所述全局共享模型对所述训练数据进行处理,获得第一预测结果;
所述服务器端依据所述上一轮训练完成的全局共享模型集合对所述训练数据进行处理,获得第二预测结果集合;
所述服务器端依据所述初步训练完成的全局共享模型对所述训练数据进行处理,获得第三预测结果;
所述服务器端依据所述第一预测结果、所述第二预测结果集合、所述第三预测结果和预先构建的加权对比正则化损失函数确定第一损失值;
所述服务器端依据所述第三预测结果和预先构建的监督重建损失函数确定第二损失值;
所述服务器端依据所述第一损失值和所述第二损失值对所述初步训练完成的全局共享模型进行训练,获得所述训练完成的全局共享模型。
优选的,所述训练完成的全局共享模型集合包含全部训练完成的全局共享模型;所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型的步骤,包括:
所述服务器端将全部所述训练完成的全局共享模型的平均值设置为所述全局共享模型。
一种基于特异性联邦学习的模型训练方法,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练方法针对所述至少两个客户端中的任意一个客户端;所述模型训练方法包括:
所述客户端接收所述服务器端发送的所述全局共享模型;
所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
所述客户端将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
优选的,所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型的步骤,包括:
所述客户端依据所述全局共享模型对所述本地数据进行处理,获得第四预测结果;
所述客户端依据所述本地模型对所述本地数据进行处理,获得第五预测结果;
所述客户端依据所述第四预测结果、所述第五预测结果和预先构建的本地损失函数,确定第三损失值;
所述客户端依据所述第三损失值对所述本地模型进行训练,获得所述训练完成的本地模型。
优选的,所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型的步骤,包括:
所述客户端依据所述训练完成的本地模型对所述本地数据进行处理,获得第六预测结果;
所述客户端依据所述全局共享模型对所述本地数据进行处理,获得第七预测结果;
所述客户端依据所述第六预测结果、所述第七预测结果和预先构建的共享损失函数,确定第四损失值;
所述客户端依据所述第四损失值对所述全局共享模型进行训练,获得所述训练完成的全局共享模型。
一种基于特异性联邦学习的模型训练装置,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练装置针对所述服务器端;所述模型训练装置包括:
全局共享模型发送模块,用于发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;
初级模型接收模块,用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;
初级模型训练模块,用于当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
全局模型确定模块,用于依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
一种基于特异性联邦学习的模型训练装置,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练装置针对所述至少两个客户端中的任意一个客户端;所述模型训练装置包括:
全局共享模型接收模块,用于接收所述服务器端发送的所述全局共享模型;
本地模型训练模块,用于依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
全局共享模型训练模块,用于依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
初级模型发送模块,用于将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
一种机器学习系统,包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;
所述服务器端,用于发送所述全局共享模型至每个所述客户端;
所述客户端,用于接收所述服务器端发送的所述全局共享模型;
所述客户端,还用于依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
所述客户端,还用于依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
所述客户端,还用于将所述初步训练完成的全局共享模型发送至所述服务器端;
所述服务器端,还用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;
所述服务器端,还用于当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端,还用于依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
本申请具有以下优点:
在本申请的实施例中,通过所述服务器端发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型,可以在满足隐私保护机制的同时缓解客户端在训练过程中的域漂移,促进收敛。
附图说明
为了更清楚地说明本申请的技术方案,下面将对本申请的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一实施例提供的一种基于特异性联邦学习的模型训练方法的框架概述示意图;
图2是本申请一实施例提供的一种基于特异性联邦学习的模型训练方法的步骤流程图;
图3是本申请一实施例提供的一种基于特异性联邦学习的模型训练方法的步骤流程图;
图4是本申请一实施例提供的一种基于特异性联邦学习的模型训练方法的步骤流程图;
图5是本申请一实施例提供的一种基于fastMRI、Brats、SMS和uMR数据集的潜在特征T-SNE的可视化示意图。
图6是本申请一实施例提供的一种基于特异性联邦学习的模型训练装置的结构框图;
图7是本申请一实施例提供的一种基于特异性联邦学习的模型训练装置的结构框图;
图8是本申请一实施例提供的一种计算机设备的结构示意图。
说明书附图中的附图标记如下:
12、计算机设备;14、外部设备;16、处理单元;18、总线;20、网络适配器;22、I/O接口;24、显示器;28、内存;30、随机存取存储器;32、高速缓存存储器;34、存储系统;40、程序/实用工具;42、程序模块。
具体实施方式
为使本申请的所述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本申请作进一步详细的说明。显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
参照图1,,针对域漂移影响下多机构联邦重建精度不高的问题,本申请将MR图像重建模型分为两部分,一个存储在服务器端的全局共享模型用于学习广义表示,以及一个存储在客户端的本地模型用于探索客户端域分布的独特性。此外,为了减少服务器端和客户端之间的偏移量,本申请还引入了一个加权对比正则化函数来校正全局泛化的更新方向,具体来说,将客户端初步训练完成的全局共享模型(锚点)拉向全局共享模型(正点),并将其推离上一轮训练完成的全局共享模型集合(负点)。本申请可以可以在满足隐私保护机制的同时缓解客户端在训练过程中的域漂移,促进收敛,实现模型性能的显著改善。
参照图2,示出了本申请一实施例提供的一种基于特异性联邦学习的模型训练方法,所述模型训练方法用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练方法针对所述服务器端;所述模型训练方法包括:
S110、所述服务器端发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;
S120、所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型;
S130、当所述上一轮训练完成的全局共享模型集合非空时,所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
S140、所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
在本申请的实施例中,通过所述服务器端发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型,可以在满足隐私保护机制的同时缓解客户端在训练过程中的域漂移,促进收敛。
下面,将对本示例性实施例中一种基于特异性联邦学习的模型训练方法作进一步地说明。
如所述步骤S110所述,所述服务器端发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端。
需要说明的是,所述服务器端可以循环执行如S110-S140所述的训练步骤,也即将每一轮训练输出得到的所述全局共享模型作为下一轮训练输入的所述全局共享模型。存储于所述服务器端的所述全局共享模型和所述上一轮训练完成的全局共享模型集合,以及存储于所述客户端的所述本地模型在每一轮的训练过程中均得到更新。本申请中涉及的模型可以是神经网络模型,例如卷积神经网络模型、循环神经网络模型、深度残差网络模型等等。本申请对所涉及的模型的具体类别不作限定。
所述服务器端发送所述全局共享模型至每个所述客户端,可以理解为所述服务器端发送的是完整的所述全局共享模型,也可以理解为所述服务器端发送的是所述全局共享模型的全部权重参数,或者可以理解为所述服务器端发送的是所述全局共享模型的部分权重参数,所述部分权重参数指所述全局共享模型相比于上一轮全局共享模型模型有更新的权重参数。
如所述步骤S120所述,所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型。
由于不同所述客户端对数据处理的速度不同,所述服务器端可以在接收到全部所述客户端发送的所述初步训练完成的全局共享模型后再执行S130的步骤,也可以按照接收次序对每个所述客户端发送的所述初步训练完成的全局共享模型分别进行处理,并在对全部所述初步训练完成的全局共享模型处理完成后执行S140的步骤。
所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型,可以理解为所述服务器端接收的是完整的所述初步训练完成的全局共享模型,也可以理解为所述服务器端接收的是所述初步训练完成的全局共享模型的全部权重参数,或者可以理解为所述服务器端接收的是所述初步训练完成的全局共享模型的部分权重参数,所述部分权重参数指所述初步训练完成的全局共享模型相比于所述全局共享模型有更新的权重参数。
如所述步骤S130所述,当所述上一轮训练完成的全局共享模型集合非空时,所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合。
当所述上一轮训练完成的全局共享模型集合非空时,也即从第二轮训练起,所述服务器端通过所述训练数据、预先构建的监督重建损失函数和加权对比正则化损失函数对每个所述初步训练完成的全局共享模型进行训练,获得对应于每一个所述初步训练完成的全局共享模型的训练完成的全局共享模型,形成包含全部所述训练完成的全局共享模型的所述训练完成的全局共享模型集合。
由于每一轮训练均需要在客户端更新和服务器端更新之间交替进行。考虑将所述MR图像重建模型划分为存储于所述服务器端的所述全局共享模型和存储于第k个所述客户端的所述本地模型,以共享全局信息和寻找唯一的深度信息。监督重建损失函数可以表达为:
其中,Ge和分别表示所述全局共享模型和所述本地模型,x∈CM表示欠采样图像,y表示全采样图像,x、y组成预存于所述服务器端的所述训练数据,K表示所述客户端的总数。需要注意的是,所述全局共享模型是由所述服务器端和所述客户端共同学习的。虽然所述客户端已经共享了所述初步训练完成的全局共享模型给所述服务器端来寻找多个所述客户端之间的通用表示,但是在迭代优化过程中,所述全局共享模型与所述初步训练完成的全局共享模型之间总是存在偏移的,主要是由局部优化过程中的域移位引起的。为了进一步修正局部更新,使模型具有全局识别能力,本申请在所述全局共享模型和所述本地模型之间引入加权对比正则化,迫使所述全局共享模型学习更强的广义表示。与传统的对比学习不同,本申请不需要从数据中寻找正负对,而是直接对网络参数的更新方向进行正则化。这使得梯度更新可以更直接地进行校正,而不需要在每次迭代中依赖大的训练样本数。
假设第k个所述客户端执行一个本地更新,它首先从所述服务器端接收所述全局共享模型,然后基于这些数据执行一个本地迭代更新。然而,来自所述服务器端的全局参数总是比本地参数有更小的偏差。本申请将所述加权对比正则化损失函数定义为:
结合所述监督重建损失函数,模型的整体损失函数可表示为:
其中,μ是控制所述加权对比正则化损失函数的权重的超参数。
在获得所述训练完成的全局共享模型集合之后,所述服务器端还依据所述训练完成的全局共享模型集合对所述上一轮训练完成的全局共享模型集合进行更新,以确保下一轮通信开始前存储于所述服务器端的所述上一轮训练完成的全局共享模型集合是本轮通信过程中获得的所述训练完成的全局共享模型集合。
如所述步骤S140所述,所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
所述服务器端可以采用多种融合算法对所述训练完成的全局共享模型集合中包含的多个所述训练完成的全局共享模型进行融合处理,以更新所述全局共享模型。例如可以对多个所述训练完成的全局共享模型求平均值以更新所述全局共享模型,或者可以对多个所述训练完成的全局共享模型进行加权处理,以更新所述全局共享模型,或者可以采用其他预设的算法对多个所述训练完成的全局共享模型进行处理,以更新所述全局共享模型。
参照图3,在本申请一实施例中,所述步骤S120之后,还包括:
S210、当所述上一轮训练完成的全局共享模型集合为空时,所述服务器端依据所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
S220、所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
如所述步骤S210所述,当所述上一轮训练完成的全局共享模型集合为空时,所述服务器端依据所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合。
当所述上一轮训练完成的全局共享模型集合为空时,也即为第一轮训练时,所述服务器端通过所述训练数据和所述监督重建损失函数对每个所述初步训练完成的全局共享模型进行训练,获得对应于每一个所述初步训练完成的全局共享模型的训练完成的全局共享模型,并形成包含全部所述训练完成的全局共享模型的所述训练完成的全局共享模型集合。
如所述步骤S220所述,所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
所述服务器端可以采用多种融合算法对所述训练完成的全局共享模型集合中包含的多个所述训练完成的全局共享模型进行融合处理,以更新所述全局共享模型。例如可以对多个所述训练完成的全局共享模型求平均值以更新所述全局共享模型,或者可以对多个所述训练完成的全局共享模型进行加权处理,以更新所述全局共享模型,或者可以采用其他预设的算法对多个所述训练完成的全局共享模型进行处理,以更新所述全局共享模型。
本实施例中,所述训练完成的全局共享模型集合包含全部训练完成的全局共享模型;所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合的步骤,包括:
对于每个所述当前初步训练完成的全局共享模型,所述服务器端执行如下步骤:
所述服务器端依据所述全局共享模型对所述训练数据进行处理,获得第一预测结果;
所述服务器端依据所述上一轮训练完成的全局共享模型集合对所述训练数据进行处理,获得第二预测结果集合;
所述服务器端依据所述初步训练完成的全局共享模型对所述训练数据进行处理,获得第三预测结果;
所述服务器端依据所述第一预测结果、所述第二预测结果集合、所述第三预测结果和预先构建的加权对比正则化损失函数确定第一损失值;
所述服务器端依据所述第三预测结果和预先构建的监督重建损失函数确定第二损失值;
所述服务器端依据所述第一损失值和所述第二损失值对所述初步训练完成的全局共享模型进行训练,获得所述训练完成的全局共享模型。
具体地,所述服务器端依据所述第一损失值和所述第二损失值确定第一总体损失值,依据所述第一总体损失值对所述初步训练完成的全局共享模型进行训练,直至所述第一总体损失值小于第一预设值时停止训练,获得所述训练完成的全局共享模型。
参照图4,示出了本申请一实施例提供的一种基于特异性联邦学习的模型训练方法,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练方法针对所述至少两个客户端中的任意一个客户端;所述模型训练方法包括:
S310、所述客户端接收所述服务器端发送的所述全局共享模型;
S320、所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
S330、所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
S340、所述客户端将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
如所述步骤S310所述,所述客户端接收所述服务器端发送的所述全局共享模型。
所述客户端接收所述服务器端发送的所述全局共享模型,可以理解为所述客户端接收的是完整的所述全局共享模型,也可以理解为所述客户端接收的是所述全局共享模型的全部权重参数,或者可以理解为所述客户端接收的是所述全局共享模型的部分权重参数,所述部分权重参数指所述全局共享模型相比于上一轮全局共享模型模型有更新的权重参数。
如所述步骤S320所述,所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型。
所述客户端根据所述服务器端发送的所述全局共享模型更新局部梯度,以找到最优的唯一局部信息,如下所示:
这个更新规则的优点是可以控制本地更新的数量,从而根据本地数据找到最优的特定于所述客户端的所述训练完成的本地模型。
在获得所述训练完成的本地模型之后,所述客户端还依据所述训练完成的本地模型对所述本地模型进行更新,以确保下一轮通信开始前存储于所述客户端的所述本地模型是本轮通信过程中获得的所述训练完成的本地模型。
如所述步骤S330所述,所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型。
本地更新完成后,所述客户端参与全局梯度更新,如下所示:
其中,Lse为预先构建的共享损失函数。
如所述步骤S340所述,所述客户端将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
所述客户端将所述初步训练完成的全局共享模型发送至所述服务器端,可以理解为所述客户端发送的是完整的所述初步训练完成的全局共享模型,也可以理解为所述客户端发送的是所述初步训练完成的全局共享模型的全部权重参数,或者可以理解为所述客户端发送的是所述初步训练完成的全局共享模型的部分权重参数,所述部分权重参数指所述初步训练完成的全局共享模型相比于所述全局共享模型有更新的权重参数。
本实施例中,所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型的步骤,包括:
所述客户端依据所述全局共享模型对所述本地数据进行处理,获得第四预测结果;
所述客户端依据所述本地模型对所述本地数据进行处理,获得第五预测结果;
所述客户端依据所述第四预测结果、所述第五预测结果和预先构建的本地损失函数,确定第三损失值;
所述客户端依据所述第三损失值对所述本地模型进行训练,获得所述训练完成的本地模型。
具体地,所述客户端依据所述第三损失值对所述本地模型进行训练,直至所述第三损失值小于第二预设值时停止训练,获得所述训练完成的本地模型。
本实施例中,所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型的步骤,包括:
所述客户端依据所述训练完成的本地模型对所述本地数据进行处理,获得第六预测结果;
所述客户端依据所述全局共享模型对所述本地数据进行处理,获得第七预测结果;
所述客户端依据所述第六预测结果、所述第七预测结果和预先构建的共享损失函数,确定第四损失值;
所述客户端依据所述第四损失值对所述全局共享模型进行训练,获得所述训练完成的全局共享模型。
具体地,所述客户端依据所述第四损失值对所述全局共享模型进行训练,直至所述第四损失值小于第三预设值时停止训练,获得所述训练完成的全局共享模型。
本申请的具体步骤参见如下算法。在每一轮通信中,所述服务器端将所述全局共享模型发送给每个所述客户端。然后,每个所述客户端根据所述全局共享模型进行局部梯度更新,获取其最优的唯一信息,如公式(2)所示。然后,所述客户端根据公式(3)参与所述服务器端的更新,随后所述服务器端根据公式(5)对局部梯度更新进行校正。
输入:K个客户端数据:D1,D2,…,Dk;所述客户端的更次次数T;通信轮数Z;超参数μ;对应于每个客户端的学习率ηk;
算法1
参照图5,为验证本申请提供的所述模型训练方法的性能,将潜在特征的T-SNE分布可视化,其中,(a-d)分别显示了SingleSet、FedAvg、不包括Lcon的FedMRI算法以及本申请的算法。在SingleSet中,每个客户端的训练只使用它们的本地数据。(a)中点的分布有明显区别,因为每个数据集都有自己的偏差,而(b)、(c)和(d)中的数据有不同程度的重叠,因为这些模型受益于FL的联邦训练机制。然而,对于分布差异较大的数据集,如fastMRI和BraTS,FedAvg几乎是失败的(见图5(b))。
值得注意的是,即使没有Lcon,本申请的方法仍然可以对齐四个不同的数据集上潜在的空间分布,这表明共享一个全局共享模型和保持一个客户特定的本地模型可以有效地减少域移位问题(见图5(c))。图5(d)显示了不同客户端的潜在特征分布明显完全混合。这可以归因于加权对比规则化使得本申请的算法能够在优化期间有效地纠正客户端和服务器端之间的偏差(见图5(d))。
在本申请一实施例中,还提供一种基于特异性联邦学习的图像处理方法,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端,所述图像重建方法针对所述服务器端,所述图像处理方法包括:
获取待处理数据;
依据基于上述任一所述的模型训练方法训练得到的所述全局共享模型对所述待处理数据进行处理,获得所述待处理数据的图像重建结果。
对于装置实施例而言,由于其与方法实施例基本相似,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
参照图6,示出了本申请一实施例提供的一种基于特异性联邦学习的模型训练装置,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练装置针对所述服务器端;所述模型训练装置包括:
全局共享模型发送模块410,用于发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;
初级模型接收模块420,用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;
初级模型训练模块430,用于当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
全局模型确定模块440,用于依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
参照图7,示出了本申请一实施例提供的一种基于特异性联邦学习的模型训练装置,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练装置针对所述至少两个客户端中的任意一个客户端;所述模型训练装置包括:
全局共享模型接收模块510,用于接收所述服务器端发送的所述全局共享模型;
本地模型训练模块520,用于依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
全局共享模型训练模块530,用于依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
初级模型发送模块540,用于将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
在本申请一实施例中,还提供一种机器学习系统,包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;
所述服务器端,用于发送所述全局共享模型至每个所述客户端;
所述客户端,用于接收所述服务器端发送的所述全局共享模型;
所述客户端,还用于依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
所述客户端,还用于依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
所述客户端,还用于将所述初步训练完成的全局共享模型发送至所述服务器端;
所述服务器端,还用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;
所述服务器端,还用于当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端,还用于依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
参照图8,示出了本申请的一种基于特异性联邦学习的模型训练方法的计算机设备,具体可以包括如下:
上述计算机设备12以通用计算设备的形式表现,计算机设备12的组件可以包括但不限于:一个或者多个处理器或者处理单元16,内存28,连接不同系统组件(包括内存28和处理单元16)的总线18。
总线18表示几类总线18结构中的一种或多种,包括存储器总线18或者存储器控制器,外围总线18,图形加速端口,处理器或者使用多种总线18结构中的任意总线18结构的局域总线18。举例来说,这些体系结构包括但不限于工业标准体系结构(ISA)总线18,微通道体系结构(MAC)总线18,增强型ISA总线18、音视频电子标准协会(VESA)局域总线18以及外围组件互连(PCI)总线18。
计算机设备12典型地包括多种计算机系统可读介质。这些介质可以是任何能够被计算机设备12访问的可用介质,包括易失性和非易失性介质,可移动的和不可移动的介质。
内存28可以包括易失性存储器形式的计算机系统可读介质,例如随机存取存储器30和/或高速缓存存储器32。计算机设备12可以进一步包括其他移动/不可移动的、易失性/非易失性计算机体统存储介质。仅作为举例,存储系统34可以用于读写不可移动的、非易失性磁介质(通常称为“硬盘驱动器”)。尽管图8中未示出,可以提供用于对可移动非易失性磁盘(如“软盘”)读写的磁盘驱动器,以及对可移动非易失性光盘(例如CD-ROM,DVD-ROM或者其他光介质)读写的光盘驱动器。在这些情况下,每个驱动器可以通过一个或者多个数据介质界面与总线18相连。存储器可以包括至少一个程序产品,该程序产品具有一组(例如至少一个)程序模块42,这些程序模块42被配置以执行本申请各实施例的功能。
具有一组(至少一个)程序模块42的程序/实用工具40,可以存储在例如存储器中,这样的程序模块42包括——但不限于——操作系统、一个或者多个应用程序、其他程序模块42以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。程序模块42通常执行本申请所描述的实施例中的功能和/或方法。
计算机设备12也可以与一个或多个外部设备14(例如键盘、指向设备、显示器24、摄像头等)通信,还可与一个或者多个使得操作人员能与该计算机设备12交互的设备通信,和/或与使得该计算机设备12能与一个或多个其他计算设备进行通信的任何设备(例如网卡,调制解调器等等)通信。这种通信可以通过I/O接口22进行。并且,计算机设备12还可以通过网络适配器20与一个或者多个网络(例如局域网(LAN)),广域网(WAN)和/或公共网络(例如因特网)通信。如图8所示,网络适配器20通过总线18与计算机设备12的其他模块通信。应当明白,尽管图8中未示出,可以结合计算机设备12使用其他硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元16、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统34等。
处理单元16通过运行存储在内存28中的程序,从而执行各种功能应用以及数据处理,例如实现本申请实施例所提供的一种基于特异性联邦学习的模型训练方法。
也即,上述处理单元16执行上述程序时实现:发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型,依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型,依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型,将所述初步训练完成的全局共享模型发送至所述处理单元16;接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
在本申请一实施例中,还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如本申请所有实施例提供的一种基于特异性联邦学习的模型训练方法。
也即,给程序被处理器执行时实现:发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型,依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型,依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型,将所述初步训练完成的全局共享模型发送至所述计算机可读存储介质;接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
可以采用一个或多个计算机可读的介质的任意组合。计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质。计算机可读存储介质例如可以是——但不限于——电、磁、光、电磁、红外线或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件或者上述的任意合适的组合。在本文件中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
计算机可读的信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括——但不限于——电磁信号、光信号或上述的任意合适的组合。计算机可读的信号介质还可以是计算机可读存储介质以外的任何计算机可读介质,该计算机可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。
可以以一种或多种程序设计语言或其组合来编写用于执行本申请操作的计算机程序代码,上述程序设计语言包括面向对象的程序设计语言——诸如Java、Smalltalk、C++,还包括常规的过程式程序设计语言——诸如“C”语言或类似的程序设计语言。程序代码可以完全地在操作人员计算机上执行、部分地在操作人员计算机上执行、作为一个独立的软件包执行、部分在操作人员计算机上部分在远程计算机上执行或者完全在远程计算机或者服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络——包括局域网(LAN)或广域网(WAN)——连接到操作人员计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。本说明书中的各个实施例均采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似的部分互相参见即可。
尽管已描述了本申请实施例的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例做出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本申请实施例范围的所有变更和修改。
最后,还需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者终端设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者终端设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者终端设备中还存在另外的相同要素。
以上对本申请所提供的一种基于特异性联邦学习的模型训练方法及装置,进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
Claims (10)
1.一种基于特异性联邦学习的模型训练方法,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练方法针对所述服务器端;其特征在于,所述模型训练方法包括:
所述服务器端发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;
所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型;
当所述上一轮训练完成的全局共享模型集合非空时,所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
2.根据权利要求1所述的模型训练方法,其特征在于,所述服务器端接收每个所述客户端发送的所述初步训练完成的全局共享模型的步骤之后,还包括:
当所述上一轮训练完成的全局共享模型集合为空时,所述服务器端依据所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
3.根据权利要求1所述的模型训练方法,其特征在于,所述训练完成的全局共享模型集合包含全部训练完成的全局共享模型;所述服务器端依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合的步骤,包括:
对于每个所述初步训练完成的全局共享模型,所述服务器端执行如下步骤:
所述服务器端依据所述全局共享模型对所述训练数据进行处理,获得第一预测结果;
所述服务器端依据所述上一轮训练完成的全局共享模型集合对所述训练数据进行处理,获得第二预测结果集合;
所述服务器端依据所述初步训练完成的全局共享模型对所述训练数据进行处理,获得第三预测结果;
所述服务器端依据所述第一预测结果、所述第二预测结果集合、所述第三预测结果和预先构建的加权对比正则化损失函数确定第一损失值;
所述服务器端依据所述第三预测结果和预先构建的监督重建损失函数确定第二损失值;
所述服务器端依据所述第一损失值和所述第二损失值对所述初步训练完成的全局共享模型进行训练,获得所述训练完成的全局共享模型。
4.根据权利要求1所述的模型训练方法,其特征在于,所述训练完成的全局共享模型集合包含全部训练完成的全局共享模型;所述服务器端依据所述训练完成的全局共享模型集合,更新所述全局共享模型的步骤,包括:
所述服务器端将全部所述训练完成的全局共享模型的平均值设置为所述全局共享模型。
5.一种基于特异性联邦学习的模型训练方法,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练方法针对所述至少两个客户端中的任意一个客户端;其特征在于,所述模型训练方法包括:
所述客户端接收所述服务器端发送的所述全局共享模型;
所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
所述客户端将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
6.根据权利要求5所述的模型训练方法,其特征在于,所述客户端依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型的步骤,包括:
所述客户端依据所述全局共享模型对所述本地数据进行处理,获得第四预测结果;
所述客户端依据所述本地模型对所述本地数据进行处理,获得第五预测结果;
所述客户端依据所述第四预测结果、所述第五预测结果和预先构建的本地损失函数,确定第三损失值;
所述客户端依据所述第三损失值对所述本地模型进行训练,获得所述训练完成的本地模型。
7.根据权利要求5所述的模型训练方法,其特征在于,所述客户端依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型的步骤,包括:
所述客户端依据所述训练完成的本地模型对所述本地数据进行处理,获得第六预测结果;
所述客户端依据所述全局共享模型对所述本地数据进行处理,获得第七预测结果;
所述客户端依据所述第六预测结果、所述第七预测结果和预先构建的共享损失函数,确定第四损失值;
所述客户端依据所述第四损失值对所述全局共享模型进行训练,获得所述训练完成的全局共享模型。
8.一种基于特异性联邦学习的模型训练装置,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练装置针对所述服务器端;其特征在于,所述模型训练装置包括:
全局共享模型发送模块,用于发送所述全局共享模型至每个所述客户端;所述客户端用于接收所述服务器端发送的所述全局共享模型;依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;将所述初步训练完成的全局共享模型发送至所述服务器端;
初级模型接收模块,用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;
初级模型训练模块,用于当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
全局模型确定模块,用于依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
9.一种基于特异性联邦学习的模型训练装置,用于机器学习系统,所述机器学习系统包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;所述模型训练装置针对所述至少两个客户端中的任意一个客户端;其特征在于,所述模型训练装置包括:
全局共享模型接收模块,用于接收所述服务器端发送的所述全局共享模型;
本地模型训练模块,用于依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
全局共享模型训练模块,用于依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
初级模型发送模块,用于将所述初步训练完成的全局共享模型发送至所述服务器端;所述服务器端用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
10.一种机器学习系统,其特征在于,包括服务器端和至少两个客户端;所述服务器端存储有全局共享模型、上一轮训练完成的全局共享模型集合和训练数据,对于第一轮训练,所述上一轮训练完成的全局共享模型集合为空集;每个所述客户端分别存储有本地模型和本地数据;
所述服务器端,用于发送所述全局共享模型至每个所述客户端;
所述客户端,用于接收所述服务器端发送的所述全局共享模型;
所述客户端,还用于依据所述全局共享模型和所述本地数据对所述本地模型进行训练,获得训练完成的本地模型;
所述客户端,还用于依据所述训练完成的本地模型和所述本地数据对所述全局共享模型进行训练,获得初步训练完成的全局共享模型;
所述客户端,还用于将所述初步训练完成的全局共享模型发送至所述服务器端;
所述服务器端,还用于接收每个所述客户端发送的所述初步训练完成的全局共享模型;
所述服务器端,还用于当所述上一轮训练完成的全局共享模型集合非空时,依据所述全局共享模型、所述上一轮训练完成的全局共享模型集合和所述训练数据对每个所述初步训练完成的全局共享模型进行训练,获得训练完成的全局共享模型集合;
所述服务器端,还用于依据所述训练完成的全局共享模型集合,更新所述全局共享模型。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210212867.8A CN114627202A (zh) | 2022-03-04 | 2022-03-04 | 一种基于特异性联邦学习的模型训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210212867.8A CN114627202A (zh) | 2022-03-04 | 2022-03-04 | 一种基于特异性联邦学习的模型训练方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114627202A true CN114627202A (zh) | 2022-06-14 |
Family
ID=81899597
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210212867.8A Pending CN114627202A (zh) | 2022-03-04 | 2022-03-04 | 一种基于特异性联邦学习的模型训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114627202A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116664456A (zh) * | 2023-08-02 | 2023-08-29 | 暨南大学 | 一种基于梯度信息的图片重建方法、系统及电子设备 |
CN116911403A (zh) * | 2023-06-06 | 2023-10-20 | 北京邮电大学 | 联邦学习的服务器和客户端的一体化训练方法及相关设备 |
-
2022
- 2022-03-04 CN CN202210212867.8A patent/CN114627202A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116911403A (zh) * | 2023-06-06 | 2023-10-20 | 北京邮电大学 | 联邦学习的服务器和客户端的一体化训练方法及相关设备 |
CN116911403B (zh) * | 2023-06-06 | 2024-04-26 | 北京邮电大学 | 联邦学习的服务器和客户端的一体化训练方法及相关设备 |
CN116664456A (zh) * | 2023-08-02 | 2023-08-29 | 暨南大学 | 一种基于梯度信息的图片重建方法、系统及电子设备 |
CN116664456B (zh) * | 2023-08-02 | 2023-11-17 | 暨南大学 | 一种基于梯度信息的图片重建方法、系统及电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
EP3511942B1 (en) | Cross-domain image analysis using deep image-to-image networks and adversarial networks | |
CN111008688B (zh) | 网络训练期间使用环路内数据增加的神经网络 | |
CN111178542B (zh) | 基于机器学习建模的系统和方法 | |
US10726555B2 (en) | Joint registration and segmentation of images using deep learning | |
US10755410B2 (en) | Method and apparatus for acquiring information | |
CN114627202A (zh) | 一种基于特异性联邦学习的模型训练方法及装置 | |
EP3872764B1 (en) | Method and apparatus for constructing map | |
US20180225823A1 (en) | Adversarial and Dual Inverse Deep Learning Networks for Medical Image Analysis | |
CN112308157B (zh) | 一种面向决策树的横向联邦学习方法 | |
WO2021114105A1 (zh) | 低剂量ct图像去噪网络的训练方法及系统 | |
US10929643B2 (en) | 3D image detection method and apparatus, electronic device, and computer readable medium | |
JP2021056995A (ja) | 医用情報処理装置、医用情報処理システム及び医用情報処理方法 | |
US11227689B2 (en) | Systems and methods for verifying medical diagnoses | |
US11430123B2 (en) | Sampling latent variables to generate multiple segmentations of an image | |
US20220414849A1 (en) | Image enhancement method and apparatus, and terminal device | |
CN109243600B (zh) | 用于输出信息的方法和装置 | |
CN112907439A (zh) | 一种基于深度学习的仰卧位和俯卧位乳腺图像配准方法 | |
CN113362314B (zh) | 医学图像识别方法、识别模型训练方法及装置 | |
CN111091010A (zh) | 相似度确定、网络训练、查找方法及装置和存储介质 | |
CN109961435B (zh) | 脑图像获取方法、装置、设备及存储介质 | |
WO2023216720A1 (zh) | 图像重建模型的训练方法、装置、设备、介质及程序产品 | |
US20220261985A1 (en) | System for determining the presence of features in a dataset | |
CN113673476A (zh) | 人脸识别模型训练方法、装置、存储介质与电子设备 | |
CN111209946B (zh) | 三维图像处理方法、图像处理模型训练方法及介质 | |
US20230274436A1 (en) | Automated Medical Image and Segmentation Quality Assessment for Machine Learning Tasks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |